Update.
author    François Fleuret <francois@fleuret.org>
          Sun, 8 Sep 2024 10:31:09 +0000 (12:31 +0200)
committer François Fleuret <francois@fleuret.org>
          Sun, 8 Sep 2024 10:31:09 +0000 (12:31 +0200)

diff --git a/attae.py b/attae.py
index 7bd4a44..e9e4bff 100755
--- a/attae.py
+++ b/attae.py
@@ -102,7 +102,7 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(vocabulary_size, dim_model),
+            nn.Embedding(2 * vocabulary_size, dim_model),
             nn.Dropout(dropout),
         )
 
@@ -143,7 +143,8 @@ class AttentionAE(nn.Module):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
 
-    def forward(self, x, mask=None):
+    def forward(self, x):
+        x = 2 * x[:, :, 0] + x[:, :, 1]
         x = self.embedding(x)
         x = self.positional_encoding(x)
         x = self.trunk(x)
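
The two hunks above change the model's input contract: forward() now takes a
(batch, seq, 2) integer tensor instead of a flat token tensor plus a mask,
fuses the two channels into a single index with 2 * x[:, :, 0] + x[:, :, 1],
and the embedding table is doubled so the fused index stays in range. A
minimal sketch of the arithmetic, assuming the second channel is a binary
flag (the diff does not say what it encodes):

    import torch
    import torch.nn as nn

    vocabulary_size, dim_model = 10, 16
    embedding = nn.Embedding(2 * vocabulary_size, dim_model)

    # Channel 0: token id in [0, vocabulary_size).
    # Channel 1: assumed to be a flag in {0, 1}.
    tokens = torch.randint(0, vocabulary_size, (4, 7))
    flags = torch.randint(0, 2, (4, 7))
    x = torch.stack([tokens, flags], dim=-1)   # (4, 7, 2)

    # 2 * token + flag lies in [0, 2 * vocabulary_size), hence the doubled table.
    fused = 2 * x[:, :, 0] + x[:, :, 1]        # (4, 7)
    y = embedding(fused)                       # (4, 7, 16)

With a binary second channel the largest fused index is
2 * (vocabulary_size - 1) + 1 = 2 * vocabulary_size - 1, which is exactly why
nn.Embedding grows to 2 * vocabulary_size rows.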
diff --git a/main.py b/main.py
index d90a3df..9285337 100755
--- a/main.py
+++ b/main.py
@@ -999,8 +999,8 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 models = []
 
 for i in range(args.nb_models):
-    model = MyAttentionAE(
-        # model = attae.AttentionAE(
+    # model = MyAttentionAE(
+    model = attae.AttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1338,6 +1338,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     else:
         log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
 
+    # one_ae_epoch(model, quiz_machine, n_epoch, None)
+    # exit(0)
+
     # --------------------------------------------------------------------
 
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
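
For completeness, a hedged sketch of the caller side after this commit: the
models are now built as attae.AttentionAE, and forward() no longer accepts a
mask argument, so it is called with the (batch, seq, 2) tensor directly.
Constructor keywords are inferred from what the hunks show (vocabulary_size,
dim_model, dim_keys in the main.py call; nb_heads and dropout in attae.py's
context); the concrete values below are placeholders, not values from the
repository:

    import torch
    import attae

    model = attae.AttentionAE(
        vocabulary_size=100,  # placeholder
        dim_model=64,         # placeholder; must be divisible by nb_heads
        dim_keys=16,          # keyword shown in the main.py hunk
        nb_heads=4,           # appears in the attae.py assert
        dropout=0.1,          # appears in the embedding's nn.Dropout
    )

    x = torch.randint(0, 100, (2, 50, 2))
    x[:, :, 1] = torch.randint(0, 2, (2, 50))  # second channel assumed binary
    y = model(x)  # passing mask=... would now raise a TypeError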