Add the Reasoning and LearningToBeMe modules to attae.py and use them in the "aebn" test path of main.py.
author     François Fleuret <francois@fleuret.org>
           Wed, 9 Oct 2024 08:53:32 +0000 (10:53 +0200)
committer  François Fleuret <francois@fleuret.org>
           Wed, 9 Oct 2024 08:53:32 +0000 (10:53 +0200)
attae.py
main.py

diff --git a/attae.py b/attae.py
index 2b231de..c386e3a 100755
--- a/attae.py
+++ b/attae.py
@@ -286,6 +286,97 @@ class FunctionalAttentionAE(nn.Module):
 ######################################################################
 
 
+class Reasoning(nn.Module):
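+    # A trunk that prepends a set of learned "function" tokens (x_star) to
+    # the input sequence, runs a first trunk over the tokens and the full
+    # sequence jointly, then splits the sequence into nb_chunks chunks and
+    # runs a second trunk on each chunk together with its own copy of the
+    # function tokens.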
+    def __init__(
+        self,
+        nb_f_tokens,
+        nb_chunks,
+        dim_model,
+        dim_qk,
+        dim_hidden,
+        nb_heads=1,
+        nb_blocks=1,
+        attention=vanilla_attention,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
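+        # Parameters drawn from a normal distribution, scaled by 1/sqrt of
+        # the last dimension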
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.nb_chunks = nb_chunks
+        self.x_star = randw(nb_f_tokens, dim_model)
+
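+        # Two attention trunks: one for the joint pass over the function
+        # tokens and the full sequence, one applied per chunk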
+        self.trunk_joint = create_trunk(
+            dim_model=dim_model,
+            dim_keys=dim_qk,
+            dim_hidden=dim_hidden,
+            nb_heads=nb_heads,
+            nb_blocks=nb_blocks,
+            dropout=attention_dropout,
+        )
+
+        self.trunk_marginal = create_trunk(
+            dim_model=dim_model,
+            dim_keys=dim_qk,
+            dim_hidden=dim_hidden,
+            nb_heads=nb_heads,
+            nb_blocks=nb_blocks,
+            dropout=attention_dropout,
+        )
+
+    def forward(self, x_q):
+        T, S = x_q.size(1), self.x_star.size(0)
+        nb, dim = x_q.size(0), x_q.size(2)
+        x_star = self.x_star.unsqueeze(0).expand(nb, -1, -1)
+
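+        # Joint pass: process the function tokens and the full sequence together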
+        x = torch.cat([x_star, x_q], dim=1)
+        x = self.trunk_joint(x)
+
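+        # Split off the function tokens, replicate them for each of the
+        # nb_chunks chunks, and process each (tokens, chunk) pair separately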
+        f, x = x[:, : x_star.size(1), :], x[:, x_star.size(1) :, :]
+        x = x.reshape(nb, self.nb_chunks, T // self.nb_chunks, dim)
+        f = f.unsqueeze(1).expand(nb, self.nb_chunks, S, dim)
+        x = x.reshape(nb * self.nb_chunks, T // self.nb_chunks, dim)
+        f = f.reshape(nb * self.nb_chunks, S, dim)
+        x = torch.cat([f, x], dim=1)
+        x = self.trunk_marginal(x)
+
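+        # Drop the function tokens and restore the (batch, T, dim) shape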
+        x = x[:, x_star.size(1) :, :]
+        x = x.reshape(nb, T, dim)
+
+        return x
+
+
+######################################################################
+
+
+class LearningToBeMe(nn.Module):
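+    # Wraps a "real" module and a "mimic" module. The mimic is fed a
+    # detached copy of the input; its relative error w.r.t. the real
+    # output is exposed as self.error (added to the training loss in
+    # main.py when args.test == "aebn"), and for the samples where that
+    # error is at most epsilon, the mimic's output is used in place of
+    # the real one.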
+    def __init__(self, real_f, mimic_f, epsilon):
+        super().__init__()
+        self.real_f = real_f
+        self.mimic_f = mimic_f
+        self.epsilon = epsilon
+
+    def forward(self, x):
+        y_real_f = self.real_f(x)
+        y_mimic_f = self.mimic_f(x.detach())
+        z_real_f = y_real_f.flatten(1)
+        z_mimic_f = y_mimic_f.flatten(1)
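+        # Per-sample relative L2 error of the mimic w.r.t. the detached real output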
+        e = (
+            (z_real_f.detach() - z_mimic_f).pow(2).sum(dim=1, keepdim=True)
+            / z_real_f.detach().pow(2).sum(dim=1, keepdim=True)
+        ).sqrt()
+        self.error = e.mean()
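+        # m = 1 where the mimic is accurate enough; nb_me is the fraction
+        # of samples still served by real_f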
+        m = (e <= self.epsilon).float()
+        self.nb_me = (1 - m).mean()
+        z = (1 - m) * z_real_f + m * z_mimic_f
+        y = z.reshape(y_real_f.shape)
+        return y
+
+
+######################################################################
+
+
 if __name__ == "__main__":
     model = FunctionalAttentionAE(
         vocabulary_size_in=100,
diff --git a/main.py b/main.py
index f6cf450..41862a2 100755
--- a/main.py
+++ b/main.py
@@ -557,6 +557,14 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
             logits.transpose(1, 2), targets, reduction="none"
         )
         loss = (loss_per_token * masks).mean()
+
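+        # Under the "aebn" test, add the mimicry errors exposed by
+        # LearningToBeMe modules to the loss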
+        if args.test == "aebn":
+            error = 0
+            for m in model.modules():
+                if hasattr(m, "error"):
+                    error = error + m.error
+            loss = loss + error
+
         acc_loss += loss.item() * imt.size(0)
         nb_samples += imt.size(0)
 
@@ -566,6 +574,14 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
             if nb_samples % args.batch_size == 0:
                 model.optimizer.step()
 
+    if args.test == "aebn":
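+        # Log the last mini-batch's error and, for every LearningToBeMe
+        # module, the fraction of samples still routed through real_f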
+        nb_me = []
+        for m in model.modules():
+            if hasattr(m, "nb_me"):
+                nb_me.append(m.nb_me.item())
+
+        log_string(f"{label}_error {n_epoch} model {model.id} {error} nb_me {nb_me}")
+
     log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
 
 
@@ -913,16 +929,27 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 if args.test == "aebn":
-    model = new_model()
+    model = new_model(0)
+    # f = model.trunk[len(model.trunk) // 2 :]
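+    # Replace the model's trunk with the Reasoning module defined in attae.py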
+    model.trunk = attae.Reasoning(
+        nb_f_tokens=250,
+        nb_chunks=2,
+        dim_model=args.dim_model,
+        dim_qk=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks // 2,
+        attention_dropout=args.dropout,
+    )
 
-    # model.trunk = (
-    # model.trunk[: len(model.trunk) // 2] + model.trunk[len(model.trunk) // 2 :]
+    # model.trunk = model.trunk[: len(model.trunk) // 2] + nn.Sequential(
+    # attae.LearningToBeMe(f, g, 1e-1)
     # )
 
-    model.id = 0
-    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-    model.test_accuracy = 0.0
-    model.nb_epochs = 0
+    model.id = 0
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.test_accuracy = 0.0
+    model.nb_epochs = 0
 
     for n_epoch in range(args.nb_epochs):
         one_complete_epoch(
@@ -933,6 +960,8 @@ if args.test == "aebn":
             local_device=main_device,
         )
 
+    exit(0)
+
 ######################################################################
 
 train_c_quizzes, test_c_quizzes = None, None