Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 9 Oct 2024 15:42:51 +0000 (17:42 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 9 Oct 2024 15:42:51 +0000 (17:42 +0200)
attae.py
main.py

index d0f4e6b..5d9da2e 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -24,7 +24,8 @@ class VaswaniPositionalEncoding(nn.Module):
         t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
         j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
         k = j % 2  # works with float, weird
-        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+        period = self.len_max ** ((j - k) / x.size(2))
+        pe = torch.sin(t / period + (math.pi / 2) * k)
         y = x + pe
         return y
 
@@ -330,27 +331,26 @@ class Reasoning(nn.Module):
 
     def forward(self, x_q):
         T, S = x_q.size(1), self.x_star.size(0)
-        nb, dim = x_q.size(0), x_q.size(2)
-        x_star = self.x_star.unsqueeze(0).expand(nb, -1, -1)
+        nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
+
+        x_star = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
 
         x = torch.cat([x_star, x_q], dim=1)
         x = self.trunk_joint(x)
 
-        f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :]
+        f, x = x[:, :S, :], x[:, S:, :]
 
         if hasattr(self, "forced_f") and self.forced_f is not None:
             f = self.forced_f.clone()
 
         self.pred_f = f.clone()
 
-        x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim)
-        f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim)
-        x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim)
-        f = f.reshape(nb * self.nb_chunks, S, dim)
+        x = x.reshape(nb * nc, T // nc, dim)
+        f = f.repeat(nc, 1, 1)
         x = torch.cat([f, x], dim=1)
         x = self.trunk_marginal(x)
 
-        x = x[:, x_star.size(1) :, :]
+        x = x[:, S:, :]
         x = x.reshape(nb, T, dim)
 
         return x
diff --git a/main.py b/main.py
index 0bb3212..618a62e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -929,8 +929,23 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 if args.test == "aebn":
-    model = new_model(0)
-    # f = model.trunk[len(model.trunk) // 2 :]
+    model = attae.AttentionAE(
+        vocabulary_size_in=vocabulary_size * 2,
+        vocabulary_size_out=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        dropout=args.dropout,
+        len_max=1e4,
+    )
+
+    model.id = 0
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.test_accuracy = 0.0
+    model.nb_epochs = 0
+
     model.trunk = attae.Reasoning(
         nb_f_tokens=25,
         nb_chunks=2,