Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 9 Oct 2024 20:54:00 +0000 (22:54 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 9 Oct 2024 20:54:00 +0000 (22:54 +0200)
attae.py
main.py

index 5d9da2e..3eb6c4e 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -3,7 +3,7 @@
 # Any copyright is dedicated to the Public Domain.
 # https://creativecommons.org/publicdomain/zero/1.0/
 
-import math
+import math, warnings
 
 import torch
 
@@ -33,6 +33,27 @@ class VaswaniPositionalEncoding(nn.Module):
 ######################################################################
 
 
+class BlockRandomPositionalEncoding(nn.Module):
+    def __init__(self, dim, block_size, nb_blocks):
+        super().__init__()
+        self.pe_inside = nn.Parameter(torch.randn(1, block_size, dim) / math.sqrt(dim))
+        self.pe_per_blocks = nn.Parameter(
+            torch.randn(1, nb_blocks, dim) / math.sqrt(dim)
+        )
+
+    def forward(self, x):
+        pe = self.pe_inside.repeat(
+            x.size(0), self.pe_per_blocks.size(1), 1
+        ) + self.pe_per_blocks.repeat_interleave(self.pe_inside.size(1), dim=1).repeat(
+            x.size(0), 1, 1
+        )
+        y = x + pe
+        return y
+
+
+######################################################################
+
+
 class WithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
@@ -171,9 +192,15 @@ class AttentionAE(nn.Module):
 
     def forward(self, x):
         x = self.embedding(x)
+
+        warnings.warn("flipping order for symmetry check", RuntimeWarning)
+        x = torch.cat([x[:, 200:], x[:, :200]], dim=1)
         x = self.positional_encoding(x)
+        x = torch.cat([x[:, 200:], x[:, :200]], dim=1)
+
         x = self.trunk(x)
         x = self.readout(x)
+
         return x
 
 
@@ -330,6 +357,9 @@ class Reasoning(nn.Module):
         )
 
     def forward(self, x_q):
+        #!!!!!!!!!!!!!!!!!!!!
+        # x_q = torch.cat([x_q[:,200:,:], x_q[:,:200,:]],dim=1)
+
         T, S = x_q.size(1), self.x_star.size(0)
         nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
 
@@ -339,12 +369,6 @@ class Reasoning(nn.Module):
         x = self.trunk_joint(x)
 
         f, x = x[:, :S, :], x[:, S:, :]
-
-        if hasattr(self, "forced_f") and self.forced_f is not None:
-            f = self.forced_f.clone()
-
-        self.pred_f = f.clone()
-
         x = x.reshape(nb * nc, T // nc, dim)
         f = f.repeat(nc, 1, 1)
         x = torch.cat([f, x], dim=1)
@@ -353,6 +377,9 @@ class Reasoning(nn.Module):
         x = x[:, S:, :]
         x = x.reshape(nb, T, dim)
 
+        #!!!!!!!!!!!!!!!!!!!!
+        # x = torch.cat([x[:,200:,:], x[:,:200,:]],dim=1)
+
         return x
 
 
diff --git a/main.py b/main.py
index 618a62e..ba5c6e2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -558,13 +558,6 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
         )
         loss = (loss_per_token * masks).mean()
 
-        if args.test == "aebn":
-            error = 0
-            for m in model.modules():
-                if hasattr(m, "error"):
-                    error = error + m.error
-            loss = loss + error
-
         acc_loss += loss.item() * imt.size(0)
         nb_samples += imt.size(0)
 
@@ -574,14 +567,6 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
             if nb_samples % args.batch_size == 0:
                 model.optimizer.step()
 
-    if args.test == "aebn":
-        nb_me = []
-        for m in model.modules():
-            if hasattr(m, "nb_me"):
-                nb_me.append(m.nb_me.item())
-
-        log_string(f"{label}_error {n_epoch} model {model.id} {error} nb_me {nb_me}")
-
     log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
 
 
@@ -613,6 +598,7 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de
     # Save some images of the ex nihilo generation of the four grids
 
     result = ae_generate(model, 150, local_device=local_device).to("cpu")
+
     problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_generation_{n_epoch}_{model.id}.png",
@@ -941,10 +927,9 @@ if args.test == "aebn":
         len_max=1e4,
     )
 
-    model.id = 0
-    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-    model.test_accuracy = 0.0
-    model.nb_epochs = 0
+    model.positional_encoding = attae.BlockRandomPositionalEncoding(
+        args.dim_model, 100, 4
+    )
 
     model.trunk = attae.Reasoning(
         nb_f_tokens=25,
@@ -957,16 +942,10 @@ if args.test == "aebn":
         attention_dropout=args.dropout,
     )
 
-    # model.trunk = model.trunk[: len(model.trunk) // 2] + nn.Sequential(
-    # attae.LearningToBeMe(f, g, 1e-1)
-    # )
-
-    # model.id = 0
-    # model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-    # model.test_accuracy = 0.0
-    # model.nb_epochs = 0
-
+    model.id = 0
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.test_accuracy = 0.0
+    model.nb_epochs = 0
 
     for n_epoch in range(args.nb_epochs):
         one_complete_epoch(