Update.
author     François Fleuret <francois@fleuret.org>
           Thu, 10 Oct 2024 08:12:11 +0000 (10:12 +0200)
committer  François Fleuret <francois@fleuret.org>
           Thu, 10 Oct 2024 08:12:11 +0000 (10:12 +0200)
attae.py
main.py

diff --git a/attae.py b/attae.py
index 3eb6c4e..0d36a33 100755
--- a/attae.py
+++ b/attae.py
@@ -54,6 +54,27 @@ class BlockRandomPositionalEncoding(nn.Module):
 ######################################################################
 
 
+class AdHocPositionalEncoding(nn.Module):
+    def __init__(self, dim_model, value, trainable=False):
+        super().__init__()
+        if trainable:
+            self.value = nn.Parameter(value.clone())
+        else:
+            self.register_buffer("value", value.clone())
+        self.fc = nn.Linear(
+            in_features=value.size(-1) + dim_model, out_features=dim_model
+        )
+
+    def forward(self, x):
+        value = self.value[None, :, :].repeat(x.size(0), 1, 1)
+        x = torch.cat([value, x], dim=2)
+        y = self.fc(x)
+        return y
+
+
+######################################################################
+
+
 class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
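
A quick usage sketch of the new AdHocPositionalEncoding (illustrative only, not part of the patch): the module takes a (sequence_length, code_dim) tensor of per-position codes, keeps it as a buffer or a trainable parameter, concatenates it to every token embedding along the feature dimension, and projects the result back to dim_model with a linear layer. The shapes below (batch of 8, 400 positions, a 10-bit code, dim_model=64) are assumptions chosen for the example.

import torch
import attae

# Illustrative per-position code: 400 positions, 10 features per position.
pe = torch.randint(0, 2, (400, 10)).float()

enc = attae.AdHocPositionalEncoding(dim_model=64, value=pe)

x = torch.randn(8, 400, 64)   # (batch, seq_len, dim_model)
y = enc(x)                    # code is concatenated feature-wise, then projected back
print(y.size())               # torch.Size([8, 400, 64])
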
diff --git a/main.py b/main.py
index ba5c6e2..d5c1c5c 100755
--- a/main.py
+++ b/main.py
@@ -589,7 +589,7 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de
 
     problem.save_quizzes_as_image(
         args.result_dir,
-        f"culture_prediction_{n_epoch}_{model.id}.png",
+        f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
         quizzes=result[:128],
         predicted_parts=predicted_parts[:128],
         correct_parts=correct_parts[:128],
@@ -601,7 +601,7 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de
 
     problem.save_quizzes_as_image(
         args.result_dir,
-        f"culture_generation_{n_epoch}_{model.id}.png",
+        f"culture_generation_{n_epoch:04d}_{model.id:02d}.png",
         quizzes=result[:128],
     )
 
@@ -927,12 +927,19 @@ if args.test == "aebn":
         len_max=1e4,
     )
 
-    model.positional_encoding = attae.BlockRandomPositionalEncoding(
-        args.dim_model, 100, 4
-    )
+    # model.positional_encoding = attae.BlockRandomPositionalEncoding(
+    # args.dim_model, 100, 4
+    # )
+
+    i = torch.arange(400)[:, None]
+    k = [2**k for k in range(4)] + [10 * 2**k for k in range(4)] + [100, 200]
+    k = torch.tensor(k)[None, :]
+    pe = (i // k) % 2
+
+    model.positional_encoding = attae.AdHocPositionalEncoding(args.dim_model, pe)
 
     model.trunk = attae.Reasoning(
-        nb_f_tokens=25,
+        nb_f_tokens=8,
         nb_chunks=2,
         dim_model=args.dim_model,
         dim_qk=args.dim_keys,
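
For reference (a standalone sketch, not part of the commit): the code that replaces BlockRandomPositionalEncoding builds, for each of 400 positions, a 10-dimensional binary vector whose j-th component is (i // k_j) % 2, i.e. a square wave that flips every k_j positions, with flip intervals mixing powers of two (1, 2, 4, 8), their decimal counterparts (10, 20, 40, 80), and 100 and 200. The same construction, with a couple of shape checks added:

import torch

i = torch.arange(400)[:, None]                                            # positions 0..399
k = [2**j for j in range(4)] + [10 * 2**j for j in range(4)] + [100, 200]
k = torch.tensor(k)[None, :]                                              # flip intervals
pe = (i // k) % 2                                                         # (400, 10) binary code

print(k.squeeze().tolist())   # [1, 2, 4, 8, 10, 20, 40, 80, 100, 200]
print(pe.size())              # torch.Size([400, 10])

This pe tensor is then handed to attae.AdHocPositionalEncoding(args.dim_model, pe), which mixes it into the token embeddings through its linear projection.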