Update.
author François Fleuret <francois@fleuret.org>
Sat, 12 Oct 2024 07:39:36 +0000 (09:39 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 12 Oct 2024 07:39:36 +0000 (09:39 +0200)
attae.py
main.py

index bb97ed4..94da984 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -383,6 +383,45 @@ class Reasoning(nn.Module):
             residual_masker=residual_masker,
         )
 
+        self.mha_A = MHAttention(
+            dim_model=dim_model,
+            dim_qk=dim_qk,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            attention=vanilla_attention,
+            attention_dropout=attention_dropout,
+        )
+
+        self.mha_B = MHAttention(
+            dim_model=dim_model,
+            dim_qk=dim_qk,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            attention=vanilla_attention,
+            attention_dropout=attention_dropout,
+        )
+
+    def forward_AB(self, x_q):
+        T, S = x_q.size(1), self.x_star.size(0)
+        nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
+
+        x = x_q
+        x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
+        x = self.trunk_A(x)
+        f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim)
+        f = self.mha_A(f, x)
+
+        k = torch.arange(nb, device=x_q.device)
+        u = f[k * 2, :]
+        f[k * 2, :] = f[k * 2 + 1, :]
+        f[k * 2 + 1, :] = u
+
+        f = self.mha_B(x, f)
+        x = self.trunk_B(x)
+        x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
+
+        return x
+
     def forward(self, x_q):
         T, S = x_q.size(1), self.x_star.size(0)
         nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
diff --git a/main.py b/main.py
index 3b1caa9..c8d6f10 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -567,6 +567,9 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
             if nb_samples % args.batch_size == 0:
                 model.optimizer.step()
 
+    if train:
+        model.nb_epochs += 1
+
     log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
 
 
@@ -941,7 +944,7 @@ if args.test == "aebn":
         pe,  # trainable=True
     )
 
-    nb_f_tokens = 100
+    nb_f_tokens = 200
 
     def no_f_residual(x):
         m = x.new_full((1, x.size(1), 1), 1.0)
@@ -964,24 +967,39 @@ if args.test == "aebn":
     model.test_accuracy = 0.0
     model.nb_epochs = 0
 
-    for n_epoch in range(args.nb_epochs):
-        one_complete_epoch(
-            model,
-            n_epoch,
-            train_c_quizzes=None,
-            test_c_quizzes=None,
-            local_device=main_device,
-        )
+    if args.resume:
         filename = f"aebn_{model.id:03d}.pth"
-        torch.save(
-            {
-                "state_dict": model.state_dict(),
-                "optimizer_state_dict": model.optimizer.state_dict(),
-                "test_accuracy": model.test_accuracy,
-                "nb_epochs": model.nb_epochs,
-            },
+
+        d = torch.load(
             os.path.join(args.result_dir, filename),
+            map_location="cpu",
+            weights_only=False,
         )
+        model.load_state_dict(d["state_dict"])
+        model.optimizer.load_state_dict(d["optimizer_state_dict"])
+        model.test_accuracy = d["test_accuracy"]
+        model.nb_epochs = d["nb_epochs"]
+        log_string(f"successfully loaded {filename} nb_epochs {model.nb_epochs}")
+
+    else:
+        for n_epoch in range(args.nb_epochs):
+            one_complete_epoch(
+                model,
+                n_epoch,
+                train_c_quizzes=None,
+                test_c_quizzes=None,
+                local_device=main_device,
+            )
+            filename = f"aebn_{model.id:03d}.pth"
+            torch.save(
+                {
+                    "state_dict": model.state_dict(),
+                    "optimizer_state_dict": model.optimizer.state_dict(),
+                    "test_accuracy": model.test_accuracy,
+                    "nb_epochs": model.nb_epochs,
+                },
+                os.path.join(args.result_dir, filename),
+            )
 
     exit(0)