Update.
author     François Fleuret <francois@fleuret.org>
Thu, 29 Aug 2024 06:07:24 +0000 (08:07 +0200)
committer  François Fleuret <francois@fleuret.org>
Thu, 29 Aug 2024 06:07:24 +0000 (08:07 +0200)
main.py

diff --git a/main.py b/main.py
index 85213cb..43a8774 100755
--- a/main.py
+++ b/main.py
@@ -728,6 +728,33 @@ class MultiEmbedding(nn.Module):
         return y
 
 
+def attention_block(dim_model, dim_keys, nb_heads, dropout):
+    return WithResidual(
+        CacheWrapper(
+            nn.LayerNorm((dim_model,)),
+        ),
+        QKVAttention(
+            dim_in=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            attention_dropout=dropout,
+        ),
+    )
+
+
+def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
+    return WithResidual(
+        CacheWrapper(
+            nn.LayerNorm((dim_model,)),
+            nn.Linear(in_features=dim_model, out_features=dim_hidden),
+            nn.ReLU(),
+            nn.Linear(in_features=dim_hidden, out_features=dim_model),
+            nn.Dropout(dropout),
+        ),
+    )
+
+
 class MyAttentionAE(nn.Module):
     def __init__(
         self,
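
For context, a stand-alone sketch of the pre-norm residual pattern that the new attention_block and ffw_block helpers factor out. It is illustrative only: nn.MultiheadAttention stands in for the repository's QKVAttention, a plain Residual for WithResidual/CacheWrapper (which operate on BracketedSequence), and the signatures are simplified accordingly (no dim_keys, and the nb_heads argument that ffw_block accepts but does not use is dropped).

import torch
from torch import nn


class Residual(nn.Module):
    # y = x + f(x); plays the role of WithResidual in main.py
    def __init__(self, *f):
        super().__init__()
        self.f = nn.Sequential(*f)

    def forward(self, x):
        return x + self.f(x)


class SelfAttention(nn.Module):
    # thin wrapper so attention fits inside nn.Sequential
    def __init__(self, dim_model, nb_heads, dropout):
        super().__init__()
        self.att = nn.MultiheadAttention(
            dim_model, nb_heads, dropout=dropout, batch_first=True
        )

    def forward(self, x):
        y, _ = self.att(x, x, x, need_weights=False)
        return y


def attention_block(dim_model, nb_heads, dropout):
    return Residual(
        nn.LayerNorm(dim_model),
        SelfAttention(dim_model, nb_heads, dropout),
    )


def ffw_block(dim_model, dim_hidden, dropout):
    return Residual(
        nn.LayerNorm(dim_model),
        nn.Linear(dim_model, dim_hidden),
        nn.ReLU(),
        nn.Linear(dim_hidden, dim_model),
        nn.Dropout(dropout),
    )


x = torch.randn(2, 16, 64)
block = nn.Sequential(attention_block(64, 4, 0.1), ffw_block(64, 256, 0.1))
print(block(x).shape)  # torch.Size([2, 16, 64])
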
@@ -756,40 +783,9 @@ class MyAttentionAE(nn.Module):
         trunk_blocks = []
 
         for b in range(nb_blocks):
-            # if b == nb_blocks//2:
-            # trunk_blocks += [
-            # QKVAttention(
-            # dim_in=dim_model,
-            # dim_qk=dim_keys,
-            # dim_v=dim_model // nb_heads,
-            # nb_heads=nb_heads,
-            # attention_dropout=dropout,
-            # ),
-            # VaswaniPositionalEncoding(len_max=1e5)
-            # ]
-
             trunk_blocks += [
-                WithResidual(
-                    CacheWrapper(
-                        nn.LayerNorm((dim_model,)),
-                    ),
-                    QKVAttention(
-                        dim_in=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        attention_dropout=dropout,
-                    ),
-                ),
-                WithResidual(
-                    CacheWrapper(
-                        nn.LayerNorm((dim_model,)),
-                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                        nn.ReLU(),
-                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                        nn.Dropout(dropout),
-                    ),
-                ),
+                attention_block(dim_model, dim_keys, nb_heads, dropout),
+                ffw_block(dim_model, dim_hidden, nb_heads, dropout),
             ]
 
         self.trunk = nn.Sequential(*trunk_blocks)
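
The rewritten loop is a pure refactoring: each of the nb_blocks iterations still appends the same two pre-norm residual stages that the deleted inline code built,

    y = x + Attention(LayerNorm(x))
    z = y + Dropout(Linear_2(ReLU(Linear_1(LayerNorm(y)))))

only their construction is now delegated to attention_block and ffw_block.
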
@@ -816,6 +812,135 @@ class MyAttentionAE(nn.Module):
         return bs
 
 
+######################################################################
+
+# f = phi(A, f(A)) + phi(B, f(B))
+# \hat{f(A)} = psi(A, f)
+# \hat{A} = psi_inv(f(A), f)
+# \hat{f(B)} = psi(B, f)
+# \hat{B} = psi_inv(f(B), f)
+
+
+def attention_layer(dim_model, dim_keys, nb_heads, dropout):
+    return WithResidual(
+        CacheWrapper(
+            nn.LayerNorm((dim_model,)),
+        ),
+        QKVAttention(
+            dim_in=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            attention_dropout=dropout,
+        ),
+    )
+
+
+class FunctionalAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1024,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = CacheWrapper(
+            nn.Sequential(
+                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+            ),
+        )
+
+        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+
+        def trunk(nb, bottom=True):
+            trunk_blocks = []
+
+            la = [
+                QKVAttention(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_model // nb_heads,
+                    nb_heads=nb_heads,
+                    attention_dropout=dropout,
+                ),
+                VaswaniPositionalEncoding(len_max=1e5),
+            ]
+
+            # if not bottom:
+            # trunk_blocks += la
+
+            for b in range(nb):
+                trunk_blocks += [
+                    attention_block(dim_model, dim_keys, nb_heads, dropout),
+                    ffw_block(dim_model, dim_hidden, nb_heads, dropout),
+                ]
+
+            # if bottom:
+            # trunk_blocks += la
+
+            return nn.Sequential(*trunk_blocks)
+
+        self.phi = trunk(nb_blocks // 2, bottom=True)
+        nb_f_tokens = 200
+        self.f_tokens = nn.Parameter(
+            torch.randn(1, nb_f_tokens, dim_model) / math.sqrt(nb_f_tokens)
+        )
+        self.psi = trunk(nb_blocks // 2, bottom=False)
+        self.psi_inv = trunk(nb_blocks // 2, bottom=False)
+        self.internal_pe = VaswaniPositionalEncoding(len_max=1e5)
+
+        self.readout = CacheWrapper(
+            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        )
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, bs):
+        def cat(*x):
+            return BracketedSequence(torch.cat(x, dim=1))
+
+        if torch.is_tensor(bs):
+            return self.forward(BracketedSequence(bs)).x
+        bs = self.embedding(bs)
+        bs = self.positional_encoding(bs)
+
+        x_A, x_f_A, x_B, x_f_B = bs.x.chunk(4, dim=1)
+
+        K = self.f_tokens.size(1)
+        N, L = x_A.size()[:2]
+
+        ft = self.f_tokens.expand(N, -1, -1)
+
+        theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
+        theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
+
+        hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
+        hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
+
+        hat_A = self.psi_inv(cat(x_f_A, theta_B)).x[:, :L]
+        hat_B = self.psi_inv(cat(x_f_B, theta_A)).x[:, :L]
+
+        bs = cat(hat_A, hat_f_A, hat_B, hat_f_B)
+
+        bs = self.readout(bs)
+        return bs
+
+
 ######################################################################
 
 nb_iterations = 25
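
In FunctionalAE.forward, the sequence is split into the four quarters A, f_A, B, f_B; phi reads the learned f_tokens together with one (input, output) pair and its first K output tokens are kept as that pair's function representation theta; psi then predicts f(X) from X and the theta estimated on the other pair, while psi_inv goes from f(X) back to X. A hypothetical shape walk-through, with identity stand-ins for the three trunks and made-up sizes:

import torch

N, L, K, D = 2, 100, 200, 64        # batch, quarter length, nb_f_tokens, dim_model

x_A, x_f_A, x_B, x_f_B = (torch.randn(N, L, D) for _ in range(4))
f_tokens = torch.randn(1, K, D).expand(N, -1, -1)

phi = psi = psi_inv = lambda x: x   # identity stand-ins for the three trunks

# phi reads [f_tokens | X | f(X)]; the first K output tokens are kept as theta
theta_A = phi(torch.cat([f_tokens, x_A, x_f_A], dim=1))[:, :K, :]   # (N, K, D)
theta_B = phi(torch.cat([f_tokens, x_B, x_f_B], dim=1))[:, :K, :]   # (N, K, D)

# psi predicts f(X) from X and the function tokens estimated on the other pair
hat_f_A = psi(torch.cat([x_A, theta_B], dim=1))[:, :L]              # (N, L, D)
hat_f_B = psi(torch.cat([x_B, theta_A], dim=1))[:, :L]

# psi_inv goes the other way, from f(X) back to X
hat_A = psi_inv(torch.cat([x_f_A, theta_B], dim=1))[:, :L]          # (N, L, D)
hat_B = psi_inv(torch.cat([x_f_B, theta_A], dim=1))[:, :L]

print(hat_A.shape, hat_f_A.shape)   # torch.Size([2, 100, 64]) twice
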
@@ -926,19 +1051,25 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
 
 
 def model_ae_proba_solutions(model, input):
-    loss = 0
+    record = []
 
-    for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-        mask_generate = quiz_machine.make_quiz_mask(
-            quizzes=input, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-        )
-        targets, logits = targets_and_prediction(
-            probs_iterations, model, input, mask_generate
-        )
-        loss_per_token = F.cross_entropy(
-            logits.transpose(1, 2), targets, reduction="none"
-        )
-        loss += (loss_per_token * mask_generate).sum(dim=1)
+    for q in input.split(args.batch_size):
+        loss = 0
+
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            targets, logits = targets_and_prediction(
+                probs_iterations, model, q, mask_generate
+            )
+            loss_per_token = F.cross_entropy(
+                logits.transpose(1, 2), targets, reduction="none"
+            )
+            loss += (loss_per_token * mask_generate).sum(dim=1)
+        record.append(loss)
+
+    loss = torch.cat(record, dim=0)
 
     return (-loss).exp()
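
The change above evaluates the per-quiz losses in chunks of args.batch_size instead of one pass over the whole tensor, then concatenates the per-chunk results. A toy illustration of that split / accumulate / cat pattern, with a dummy loss standing in for the masked cross-entropy:

import torch

def chunk_loss(q):                        # stand-in for the per-quad masked cross-entropy
    return q.float().sum(dim=1)

quizzes = torch.arange(10 * 3).reshape(10, 3)
record = [chunk_loss(q) for q in quizzes.split(4)]   # chunks of 4, 4 and 2 quizzes
loss = torch.cat(record, dim=0)                      # back to one loss value per quiz
print(loss.shape, (-loss).exp().shape)               # torch.Size([10]) twice
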
 
@@ -1108,7 +1239,8 @@ noise_proba = 0.05
 models = []
 
 for i in range(args.nb_models):
-    model = MyAttentionAE(
+    # model = MyAttentionAE(
+    model = FunctionalAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1169,7 +1301,9 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-for n_epoch in range(args.nb_epochs):
+for n_epoch in range(current_epoch, args.nb_epochs):
+    start_time = time.perf_counter()
+
     state = {
         "current_epoch": n_epoch,
         # "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
@@ -1187,8 +1321,8 @@ for n_epoch in range(args.nb_epochs):
 
     # --------------------------------------------------------------------
 
-    one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
-    exit(0)
+    one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
+    exit(0)
 
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
@@ -1231,3 +1365,13 @@ for n_epoch in range(args.nb_epochs):
         )
 
         log_string(f"wrote {filename}")
+
+    # --------------------------------------------------------------------
+
+    duration = time.perf_counter() - start_time
+    str_duration = ""
+    if duration >= 60:
+        str_duration += f"{int(duration//60)}min"
+        duration = duration % 60
+    str_duration += f"{duration:.01f}s"
+    log_string(f"epoch_duration {str_duration}")
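
The duration formatting added at the end of the epoch loop can be checked in isolation; a small stand-alone version (format_duration is a hypothetical name, the script inlines this logic directly):

def format_duration(duration):
    str_duration = ""
    if duration >= 60:
        str_duration += f"{int(duration//60)}min"
        duration = duration % 60
    str_duration += f"{duration:.01f}s"
    return str_duration

print(format_duration(137.3))   # 2min17.3s
print(format_duration(42.0))    # 42.0s
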