Update. fast
authorFrançois Fleuret <francois@fleuret.org>
Wed, 21 Aug 2024 14:40:04 +0000 (16:40 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 21 Aug 2024 14:40:04 +0000 (16:40 +0200)
grids.py
main.py
mygpt.py
quiz_machine.py

index b12b4d6..c44e527 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -226,7 +226,7 @@ class Grids(problem.Problem):
         mask_ar = quizzes.new_zeros(quizzes.size())
 
         S = self.height * self.width
-        a = mask_ar.reshape(mask_ar.size(0), 4, S + 1)[:, :, 1:]
+        a = mask_ar.view(mask_ar.size(0), 4, S + 1)[:, :, 1:]
         a[:, 0, :] = quad[0]
         a[:, 1, :] = quad[1]
         a[:, 2, :] = quad[2]
diff --git a/main.py b/main.py
index 78d01ff..033b5f6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -360,6 +360,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
             output = model(
                 mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar))
             ).x
+
             loss_per_token = F.cross_entropy(
                 output.transpose(1, 2), targets, reduction="none"
             )
@@ -479,6 +480,45 @@ c_quizzes_procedure = [
 ######################################################################
 
 
+def model_proba_solutions(model, quizzes):
+    l = (
+        quiz_machine.models_logprobas(
+            model,
+            quizzes,
+            ("A", "f_A", "B", "f_B"),
+            (0, 0, 0, 2),
+            (0, 0, 1, 0),
+            (0, 0, 0, 1),
+        )
+        + quiz_machine.models_logprobas(
+            model,
+            quizzes,
+            ("f_A", "A", "f_B", "B"),
+            (0, 0, 0, 2),
+            (0, 0, 1, 0),
+            (0, 0, 0, 1),
+        )
+        + quiz_machine.models_logprobas(
+            model,
+            quizzes,
+            ("B", "f_B", "A", "f_A"),
+            (0, 0, 0, 2),
+            (0, 0, 1, 0),
+            (0, 0, 0, 1),
+        )
+        + quiz_machine.models_logprobas(
+            model,
+            quizzes,
+            ("f_B", "B", "f_A", "A"),
+            (0, 0, 0, 2),
+            (0, 0, 1, 0),
+            (0, 0, 0, 1),
+        )
+    )
+
+    return l.exp()
+
+
 def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
     # Save generated quizzes with the successive generation steps
 
@@ -493,21 +533,7 @@ def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
 
     # This is nb_quizzes x nb_models
 
-    l = [
-        quiz_machine.models_logprobas(
-            model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        for model in models
-    ]
+    l = [model_proba_solutions(model, c_quizzes) for model in models]
 
     seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
     probas = seq_logprobas.exp()
@@ -549,25 +575,6 @@ def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
 ######################################################################
 
 
-def model_proba_solutions(model, quizzes):
-    l = (
-        quiz_machine.models_logprobas(
-            model, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-    )
-
-    return l.exp()
-
-
 def create_c_quizzes(
     main_model,
     other_models,
@@ -822,48 +829,6 @@ class Recorder(nn.Module):
         return input
 
 
-######################################################################
-
-
-def save_generated_c_quizzes(model, filename, nb=64):
-    while sum([x.size(0) for x in record]) < nb:
-        model = models[torch.randint(len(models), (1,)).item()]
-        c_quizzes = quiz_machine.generate_c_quizzes(
-            64,
-            model_for_generation=model,
-            procedure=c_quizzes_procedure,
-        )
-
-        p = quiz_machine.models_logprobas(
-            model,
-            c_quizzes,
-            ("A", "f_A", "B", "f_B"),
-            (1, 1, 1, 1),
-            temperature=1,
-        ).exp()
-
-        p_hot = quiz_machine.models_logprobas(
-            model,
-            c_quizzes,
-            ("A", "f_A", "B", "f_B"),
-            (1, 1, 1, 1),
-            temperature=args.temperature_hot,
-        ).exp()
-
-        to_keep = p_hot * torch.rand(p_hot.size(), device=p_hot.device) >= p
-        record.append(c_quizzes[to_keep])
-
-        print("NB_KEPT", sum([x.size(0) for x in record]))
-
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir,
-        filename,
-        quizzes=c_quizzes,
-    )
-
-    log_string(f"wrote {filename}")
-
-
 ######################################################################
 
 for n_epoch in range(current_epoch, args.nb_epochs):
index cd5b580..dc00423 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -83,6 +83,12 @@ def mask_ar_to_ranks(mask_ar):
     return a
 
 
+# mask_ar = torch.tensor([[ 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1]])
+# print(mask_ar)
+# print(mask_ar_to_ranks(mask_ar))
+# exit(0)
+
+
 class BracketedSequence:
     def __init__(self, x, first=None, nb=None, ranks=None):
         self.x = x
@@ -193,6 +199,7 @@ class QKVAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
+        first_one=False,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -201,6 +208,8 @@ class QKVAttention(nn.Module):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
         self.attention_dropout = attention_dropout
+        self.first_one = first_one
+
         self.record_attention = False
 
         self.w_q = randw(nb_heads, dim_qk, dim_in)
@@ -243,19 +252,26 @@ class QKVAttention(nn.Module):
 
         t = torch.arange(x_q.size(1), device=a.device)
 
-        if bs_q.ranks is not None:
-            a = a.masked_fill(
-                (
-                    bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
-                    <= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
-                )
-                & (
-                    t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
-                    != t[None, None, None, : bs_kv.first + bs_kv.nb]
-                ),
-                float("-inf"),
+        assert bs_q.ranks is not None
+
+        # rank_forward = (
+        # bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
+        # >= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
+        # )
+
+        if self.first_one:
+            rank_forward = (
+                t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+                <= t[None, None, None, : bs_kv.first + bs_kv.nb]
+            )
+        else:
+            rank_forward = (
+                t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+                < t[None, None, None, : bs_kv.first + bs_kv.nb]
             )
 
+        a = a.masked_fill(rank_forward, float("-inf"))
+
         a = a.softmax(dim=3)
 
         if self.record_attention:
@@ -269,7 +285,7 @@ class QKVAttention(nn.Module):
 
         self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
 
-        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
+        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb, bs_q.ranks)
 
 
 ##############################
@@ -347,7 +363,16 @@ class MyGPT(nn.Module):
 
         self.positional_encoding = AddPositionalEncoding(len_max)
 
-        trunk_blocks = []
+        trunk_blocks = [
+            QKVAttention(
+                dim_in=dim_model,
+                dim_qk=dim_keys,
+                dim_v=dim_model // nb_heads,
+                nb_heads=nb_heads,
+                first_one=True,
+                attention_dropout=dropout,
+            )
+        ]
 
         for b in range(nb_blocks):
             trunk_blocks += [
@@ -394,7 +419,7 @@ class MyGPT(nn.Module):
         for m in self.modules():
             m.loss = 0
 
-        bs = self.shifter(bs)
+        bs = self.shifter(bs)
         bs = self.embedding(bs)
         bs = self.positional_encoding(bs)
         bs = self.trunk(bs)
index 8cec909..db58461 100755 (executable)
@@ -39,17 +39,15 @@ def one_batch_masked_inplace_autoregression(
 
     indices_1 = list(((mask_ar == 1).long().sum(0) > 0).nonzero()) + [mask.size(1)]
 
+    ranks = mygpt.mask_ar_to_ranks(mask_ar)
+
     if to_generate.min() > 0:
         model(
-            BracketedSequence(input, 0, to_generate.min())
+            BracketedSequence(input, 0, to_generate.min(), ranks=ranks)
         )  # Needed to initialize the model's cache
 
-    s = to_generate.min()
-
     for s, u in zip(indices_1[:-1], indices_1[1:]):
-        logits = model(
-            BracketedSequence(input, s, u - s, ranks=mygpt.mask_ar_to_ranks(mask_ar))
-        ).x
+        logits = model(BracketedSequence(input, s, u - s, ranks=ranks)).x
 
         if deterministic_synthesis:
             t_next = logits.argmax(dim=2)
@@ -90,10 +88,10 @@ class QuizMachine:
 
         #  - struct, quad_generate, quad_noise, quad_loss
         self.train_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
             (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
         ]
 
@@ -296,8 +294,9 @@ class QuizMachine:
         model,
         c_quizzes,
         struct,
+        mask_ar,
+        mask_noise,
         mask_loss,
-        mask_noise=None,
         temperature=1.0,
         device=None,
     ):
@@ -322,13 +321,20 @@ class QuizMachine:
 
             for input, l in zip(
                 c_quizzes.split(self.batch_size),
+                mask_ar.split(self.batch_size),
                 seq_logprobas.split(self.batch_size),
             ):
                 input = input.to(device)
                 quiz_mask_loss = self.make_quiz_mask(
                     input, struct=struct, mask=mask_loss
                 )
-                output = model(mygpt.BracketedSequence(input)).x / temperature
+                output = (
+                    model(
+                        mygpt.BracketedSequence(input),
+                        ranks=mygpt.mask_ar_to_ranks(mask_ar),
+                    ).x
+                    / temperature
+                )
                 l[...] = (
                     -F.cross_entropy(output.transpose(1, 2), input, reduction="none")
                     * quiz_mask_loss