Update. dev
author    François Fleuret <francois@fleuret.org>
Sun, 13 Oct 2024 20:12:03 +0000 (22:12 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 13 Oct 2024 20:12:03 +0000 (22:12 +0200)
attae.py
main.py

diff --git a/attae.py b/attae.py
index 94da984..d8c68cc 100755
--- a/attae.py
+++ b/attae.py
@@ -140,6 +140,86 @@ class MHAttention(nn.Module):
 ######################################################################
 
 
+class ModulatedMHAttention(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        attention=vanilla_attention,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        self.dim_qk = dim_qk
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.attention = attention
+        self.attention_dropout = attention_dropout
+        self.w_q = randw(nb_heads, dim_qk, dim_model)
+        self.w_k = randw(nb_heads, dim_qk, dim_model)
+        self.w_v = randw(nb_heads, dim_v, dim_model)
+        self.w_o = randw(nb_heads, dim_v, dim_model)
+
+    def forward(self, x_q, x_kv=None):
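+        # the first dim_qk channels of x_q carry the modulation signal, the
+        # remaining dim_model channels are the regular activations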
+        modulation, x_q = x_q[:, :, : self.dim_qk], x_q[:, :, self.dim_qk :]
+        if x_kv is None:
+            x_kv = x_q
+
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        # insert a head axis so the (N, T, dim_qk) modulation broadcasts over heads
+        q = q * modulation.sigmoid()[:, None, :, :]
+        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+        y = self.attention(q, k, v)
+        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
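+        # keep the modulation channels in the output so stacked blocks can reuse them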
+        return torch.cat([modulation, y], dim=2)
+
+
+######################################################################
+
+
+class AttentionBlock(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.ln1 = nn.LayerNorm((dim_model,))
+        self.mha = MHAttention(
+            dim_model=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            attention=vanilla_attention,
+            attention_dropout=dropout,
+        )
+        self.ln2 = nn.LayerNorm((dim_model,))
+        self.fc1 = nn.Linear(in_features=dim_model, out_features=dim_hidden)
+        self.fc2 = nn.Linear(in_features=dim_hidden, out_features=dim_model)
+        self.drop_out = nn.Dropout(dropout)
+
+    def forward(self, x):
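+        # standard pre-norm residual block: self-attention, then a two-layer MLP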
+        y = self.ln1(x)
+        y = self.mha(y)
+        x = x + y
+        y = self.ln2(x)
+        y = self.fc1(y)
+        y = F.relu(y)
+        y = self.fc2(y)
+        y = self.drop_out(y)
+        x = x + y
+        return x
+
+
 def create_trunk(
     dim_model,
     dim_keys,
@@ -149,35 +229,12 @@ def create_trunk(
     dropout=0.0,
     residual_masker=None,
 ):
-    trunk_blocks = []
-
-    for b in range(nb_blocks):
-        trunk_blocks += [
-            WithResidual(
-                masker=residual_masker,
-                f=(
-                    nn.LayerNorm((dim_model,)),
-                    MHAttention(
-                        dim_model=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        attention=vanilla_attention,
-                        attention_dropout=dropout,
-                    ),
-                ),
-            ),
-            WithResidual(
-                masker=residual_masker,
-                f=(
-                    nn.LayerNorm((dim_model,)),
-                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                    nn.ReLU(),
-                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                    nn.Dropout(dropout),
-                ),
-            ),
-        ]
+    trunk_blocks = [
+        AttentionBlock(dim_model, dim_keys, dim_hidden, nb_heads, dropout=dropout)
+        for _ in range(nb_blocks)
+    ]
 
     return nn.Sequential(*trunk_blocks)
 
@@ -289,31 +346,7 @@ class FunctionalAttentionAE(nn.Module):
             m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
             return m[None, :, None]
 
-        for b in range(nb_blocks):
-            trunk_blocks += [
-                WithMaskedResidual(
-                    masker,
-                    nn.LayerNorm((dim_model,)),
-                    MHAttention(
-                        dim_model=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        attention=no_peek_attention,
-                        attention_dropout=dropout,
-                    ),
-                ),
-                WithMaskedResidual(
-                    masker,
-                    nn.LayerNorm((dim_model,)),
-                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                    nn.ReLU(),
-                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                    nn.Dropout(dropout),
-                ),
-            ]
-
-        self.trunk = nn.Sequential(*trunk_blocks)
+        self.trunk = nn.Sequential(
+            *[
+                AttentionBlock(
+                    dim_model, dim_keys, dim_hidden, nb_heads, dropout=dropout
+                )
+                for _ in range(nb_blocks)
+            ]
+        )
 
         self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
 
@@ -360,6 +393,8 @@ class Reasoning(nn.Module):
 
         self.nb_chunks = nb_chunks
         self.x_star = randw(nb_f_tokens, dim_model)
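+        # scale the learned f tokens down at initialization; the in-place
+        # multiplication is done under no_grad so autograd does not track it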
+        with torch.no_grad():
+            self.x_star *= 1e-3
 
         self.positional_encoding = VaswaniPositionalEncoding(len_max)
 
@@ -392,52 +427,57 @@ class Reasoning(nn.Module):
             attention_dropout=attention_dropout,
         )
 
-        self.mha_B = MHAttention(
-            dim_model=dim_model,
-            dim_qk=dim_qk,
-            dim_v=dim_model // nb_heads,
-            nb_heads=nb_heads,
-            attention=vanilla_attention,
-            attention_dropout=attention_dropout,
-        )
+    def forward_(self, x_q):
+        nb, T, dim = x_q.size()
+        nc, S = self.nb_chunks, self.x_star.size(0)
 
-    def forward_AB(self, x_q):
-        T, S = x_q.size(1), self.x_star.size(0)
-        nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
-
-        x = x_q
-        x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
-        x = self.trunk_A(x)
+        x = x_q.reshape(nb * nc, T // nc, dim)
         f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim)
-        f = self.mha_A(f, x)
+        # x = torch.cat([f, x], dim=1)
+        x = self.trunk_A(x)
 
-        k = torch.arange(nb, device=x_q.device)
-        u = f[k * 2, :]
-        f[k * 2, :] = f[k * 2 + 1, :]
-        f[k * 2 + 1, :] = u
+        x_star = self.x_star.reshape(1, S, dim)
+        f = self.mha_A(x_star, x).mean(dim=1, keepdim=True)
 
-        f = self.mha_B(x, f)
-        x = self.trunk_B(x)
-        x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
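+        # swap f between consecutive pairs in the batch (0<->1, 2<->3, ...)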
+        k = torch.arange(nb * nc, device=x_q.device)
+        k = k + 1 - 2 * (k % 2)
+        f = f[k]
+        # u = x[k * 2, :S]
+        # x[k * 2, :S] = x[k * 2 + 1, :S]
+        # x[k * 2 + 1, :S] = u
+        # x[:, S:] = x_q.reshape(nb * nc, T // nc, dim)
+
+        x = self.trunk_B(x, q_modulation=f)
+
+        x = x[:, S:].reshape(nb, T, dim)
 
         return x
 
     def forward(self, x_q):
         T, S = x_q.size(1), self.x_star.size(0)
         nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
-        f = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
 
         x = x_q
         x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
-        f = f.repeat(nc, 1, 1)
-        x = torch.cat([f, x], dim=1)
         x = self.trunk_A(x)
-        k = torch.arange(nb, device=x_q.device)
-        u = x[k * 2, :S]
-        x[k * 2, :S] = x[k * 2 + 1, :S]
-        x[k * 2 + 1, :S] = u
+        # f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim)
+        # f = self.mha_A(f, x)
+
+        # if hasattr(self, "f"):
+        # if torch.is_tensor(self.f):
+        # f, self.f = self.f, f
+        # else:
+        # self.f = f
+
+        # k = torch.arange(nb, device=x_q.device)
+        # u = f[k * 2, :]
+        # f[k * 2, :] = f[k * 2 + 1, :]
+        # f[k * 2 + 1, :] = u
+
+        # x = x_q
+        # x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
+        # f = self.mha_B(x, f)
         x = self.trunk_B(x)
-        x = x[:, S:]
         x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
 
         return x
diff --git a/main.py b/main.py
index c8d6f10..85b916b 100755
--- a/main.py
+++ b/main.py
@@ -357,10 +357,11 @@ def add_input_noise_imt(imt_set, proba_input_noise):
 # Prediction
 
 
-def make_imt_samples_for_prediction(input):
+def make_imt_samples_for_prediction(input, u=None):
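+    # u, if provided, is an (nb, 4) mask selecting which of the four quiz
+    # parts to hide; by default one part is picked at random per sample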
     nb = input.size(0)
     masks = input.new_zeros(input.size())
-    u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+    if u is None:
+        u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
     masks.view(nb, 4, -1)[...] = u[:, :, None]
     targets = input
     input = (1 - masks) * targets
@@ -917,6 +918,56 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
+
+def save_f_token_manipulations(model, n_epoch, local_device):
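+    # mask the fourth part of every quiz and predict it; inside each batch,
+    # the first half is replaced with copies of a single quiz (see below)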
+    quizzes = generate_quiz_set(256, None, args.c_quiz_multiplier)
+
+    u = F.one_hot(torch.full((quizzes.size(0),), 3, device=local_device), num_classes=4)
+
+    imt_set = make_imt_samples_for_prediction(quizzes.to(local_device), u=u)
+
+    model.eval().to(local_device)
+
+    record = []
+
+    src = tqdm.tqdm(
+        imt_set.split(args.eval_batch_size),
+        dynamic_ncols=True,
+        desc="predict",
+        total=imt_set.size(0) // args.eval_batch_size,
+        delay=10,
+    )
+
+    N = args.eval_batch_size
+
+    for imt in src:
+        # some paranoia: work on a copy and force the masked input positions to zero
+        imt = imt.clone()
+        imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
+
+        model.trunk.f = True
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            batch = imt[:, 0] + imt[:, 1] * vocabulary_size
+            logits = model(batch)
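+            # replace the first half of the batch with copies of the quiz at
+            # index N//2, and move the original first half into the second half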
+            x = batch[N // 2 : N // 2 + 1].clone()
+            batch[N // 2 :] = batch[: N // 2]
+            batch[: N // 2] = x.expand(N // 2, -1)
+            logits = model(batch)
+
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        result = dist.sample()
+        record.append(result)
+
+    result = torch.cat(record).to("cpu")
+
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_f_token_manipulation.png",
+        quizzes=result[:128],
+        nrow=N // 2,
+    )
+
+
 if args.test == "aebn":
     model = attae.AttentionAE(
         vocabulary_size_in=vocabulary_size * 2,
@@ -944,7 +995,7 @@ if args.test == "aebn":
         pe,  # trainable=True
     )
 
-    nb_f_tokens = 200
+    nb_f_tokens = 8
 
     def no_f_residual(x):
         m = x.new_full((1, x.size(1), 1), 1.0)
@@ -981,6 +1032,8 @@ if args.test == "aebn":
         model.nb_epochs = d["nb_epochs"]
         log_string(f"successfully loaded {filename} nb_epochs {model.nb_epochs}")
 
+        save_f_token_manipulations(model, 0, local_device=main_device)
+
     else:
         for n_epoch in range(args.nb_epochs):
             one_complete_epoch(