Turn vanilla_attention into a compiled free function, factor model checkpointing into save_models(), and increase nb_test_samples from 1000 to 10000.
author	François Fleuret <francois@fleuret.org>
Fri, 13 Sep 2024 07:57:51 +0000 (09:57 +0200)
committer	François Fleuret <francois@fleuret.org>
Fri, 13 Sep 2024 07:57:51 +0000 (09:57 +0200)
attae.py
main.py

diff --git a/attae.py b/attae.py
index 069772b..bc90ed0 100755
--- a/attae.py
+++ b/attae.py
@@ -45,19 +45,17 @@ class WithResidual(nn.Module):
 ######################################################################
 
 
-class vanilla_attention(q, k, v):
+def vanilla_attention(q, k, v):
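+    # q: (N, heads, T, d), k and v: (N, heads, S, d); scores are (N, heads, T, S), softmaxed over S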
     a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
     a = a.softmax(dim=3)
     y = torch.einsum("nhts,nhsd->nhtd", a, v)
-
-    # y = flex_attention(q, k, v, score_mod=noop)
-
-    y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
-
     return y
 
 
-vanilla_attention = torch.compille(vanilla_attention)
+vanilla_attention = torch.compile(vanilla_attention)
+
+# y = flex_attention(q, k, v, score_mod=noop)
 
 
 class MHAttention(nn.Module):
@@ -93,7 +91,7 @@ class MHAttention(nn.Module):
         def noop(score, b, h, q_idx, kv_idx):
             return score
 
-        y = vanilla_attention(q, k, v, score_mod=noop)
+        y = vanilla_attention(q, k, v)
         # y = flex_attention(q, k, v, score_mod=noop)
 
         y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
@@ -163,7 +161,6 @@ class AttentionAE(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, x):
-        x = 2 * x[:, :, 0] + x[:, :, 1]
         x = self.embedding(x)
         x = self.positional_encoding(x)
         x = self.trunk(x)
diff --git a/main.py b/main.py
index 63cd377..0fea318 100755
--- a/main.py
+++ b/main.py
@@ -55,7 +55,7 @@ parser.add_argument("--inference_batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=25000)
 
-parser.add_argument("--nb_test_samples", type=int, default=1000)
+parser.add_argument("--nb_test_samples", type=int, default=10000)
 
 parser.add_argument("--nb_train_alien_samples", type=int, default=0)
 
@@ -1388,9 +1388,27 @@ def multithread_execution(fun, arguments):
 
 ######################################################################
 
-for n_epoch in range(current_epoch, args.nb_epochs):
-    start_time = time.perf_counter()
 
+def save_models(models, suffix=""):
+    if suffix != "":
+        suffix = "_" + suffix
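+    # one checkpoint per model: weights, optimizer state, and last test accuracy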
+    for model in models:
+        filename = f"ae_{model.id:03d}{suffix}.pth"
+        torch.save(
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "test_accuracy": model.test_accuracy,
+            },
+            os.path.join(args.result_dir, filename),
+        )
+        log_string(f"wrote {filename}")
+
+
+######################################################################
+
+for n_epoch in range(current_epoch, args.nb_epochs):
     state = {
         "current_epoch": n_epoch,
         "c_quizzes": c_quizzes,
@@ -1414,46 +1432,38 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         and time_train >= time_c_quizzes
     ):
         if c_quizzes is None:
-            for model in models:
-                filename = f"ae_{model.id:03d}_naive.pth"
-                torch.save(
-                    {
-                        "state_dict": model.state_dict(),
-                        "optimizer_state_dict": model.optimizer.state_dict(),
-                        "test_accuracy": model.test_accuracy,
-                    },
-                    os.path.join(args.result_dir, filename),
-                )
-                log_string(f"wrote {filename}")
-
-        # --------------------------------------------------------------------
+            save_models(models, "naive")
 
         last_n_epoch_c_quizzes = n_epoch
         nb_gpus = len(gpus)
         nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
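+        # ceil division: each GPU generates an equal share covering args.nb_c_quizzes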
 
-        # --------------------------------------------------------------------
+        start_time = time.perf_counter()
 
         c_quizzes, agreements = multithread_execution(
             generate_ae_c_quizzes,
             [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
         )
 
-        # --------------------------------------------------------------------
-
-        filename = f"culture_c_quiz_{n_epoch:04d}.png"
         save_c_quizzes_with_scores(
-            models, c_quizzes[:256], filename, solvable_only=False
+            models,
+            c_quizzes[:256],
+            f"culture_c_quiz_{n_epoch:04d}.png",
+            solvable_only=False,
         )
 
-        filename = f"culture_c_quiz_{n_epoch:04d}_solvable.png"
         save_c_quizzes_with_scores(
-            models, c_quizzes[:256], filename, solvable_only=True
+            models,
+            c_quizzes[:256],
+            f"culture_c_quiz_{n_epoch:04d}_solvable.png",
+            solvable_only=True,
         )
 
         log_string(f"generated_c_quizzes {c_quizzes.size()=}")
 
         time_train = 0
+
         for model in models:
             model.test_accuracy = 0
 
@@ -1467,6 +1477,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
 
+    start_time = time.perf_counter()
+
     multithread_execution(
         one_ae_epoch,
         [
@@ -1485,19 +1497,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     # --------------------------------------------------------------------
 
-    for model in models:
-        filename = f"ae_{model.id:03d}.pth"
-        torch.save(
-            {
-                "state_dict": model.state_dict(),
-                "optimizer_state_dict": model.optimizer.state_dict(),
-                "test_accuracy": model.test_accuracy,
-            },
-            os.path.join(args.result_dir, filename),
-        )
-        log_string(f"wrote {filename}")
-
-    # --------------------------------------------------------------------
+    save_models(models)
 
     duration = time.perf_counter() - start_time
     str_duration = ""