Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 22 Jul 2024 22:05:45 +0000 (00:05 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 22 Jul 2024 22:05:45 +0000 (00:05 +0200)
grids.py
main.py
mygpt.py
quiz_machine.py

index 22704b2..b531eb9 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -176,21 +176,31 @@ class Grids(problem.Problem):
                 dim=1,
             )
         else:
-            flipped = torch.cat(
+            flipped_from_forward = torch.cat(
                 [
                     quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
-                    quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+                    quizzes[:, 0 * (S + 1) : 2 * (S + 1) + S + 1],
                     quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
+                    quizzes[:, 2 * (S + 1) : 0 * (S + 1) + S + 1],
+                ],
+                dim=1,
+            )
+            flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward
+
+            flipped_from_backward = torch.cat(
+                [
+                    quizzes[:, 1 * (S + 1) : 3 * (S + 1) + S + 1],
+                    quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
+                    quizzes[:, 3 * (S + 1) : 1 * (S + 1) + S + 1],
                     quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
                 ],
                 dim=1,
             )
+            flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward
+
+            m = (flipped[:, 0] == self.token_forward).long()
 
-        m = (flipped[:, 0] == self.token_forward).long()
-        flipped[:, 0 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
-        flipped[:, 1 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
-        flipped[:, 2 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
-        flipped[:, 3 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward
+            flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward
 
         return flipped
 
diff --git a/main.py b/main.py
index f8f8502..a540cc0 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -87,6 +87,8 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
+parser.add_argument("--max_fail_to_validate", type=int, default=1)
+
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
 parser.add_argument("--proba_understands", type=float, default=0.9)
@@ -99,6 +101,8 @@ parser.add_argument("--temperature_cold", type=float, default=0.75)
 
 parser.add_argument("--nb_rounds", type=int, default=3)
 
+parser.add_argument("--noise_level", type=float, default=0)
+
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
 parser.add_argument("--p2a_only", action="store_true", default=False)
@@ -374,9 +378,21 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         if nb_train_samples % args.batch_size == 0:
             optimizer.zero_grad()
 
+        targets = input
+
+        if args.noise_level > 0:
+            m = (
+                (torch.rand(targets.size(), device=targets.device) < args.noise_level)
+                & (targets != quiz_machine.problem.token_forward)
+                & (targets != quiz_machine.problem.token_backward)
+            ).long()
+            input = (1 - m) * input.clone() + m * torch.randint(
+                vocabulary_size, input.size(), device=input.device
+            )
+
         output = model(mygpt.BracketedSequence(input)).x
         loss_per_token = F.cross_entropy(
-            output.transpose(1, 2), input, reduction="none"
+            output.transpose(1, 2), targets, reduction="none"
         )
         loss = loss_per_token.mean()
         acc_train_loss += loss.item() * input.size(0)
@@ -421,7 +437,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     nb_validated = 0
 
     recorded_validated = []
-    # recorded_too_simple = []
 
     start_time = time.perf_counter()
 
@@ -450,26 +465,33 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
         # We go through nb_rounds rounds and keep only quizzes on
-        # which models respond always the same through rounds
+        # which models respond always the same through rounds and one
+        # which N-1 succeed and one fails
+
+        ms = 0  # "model scores"
 
-        total_nb_validated = 0
-        ms = 0
         for r in range(args.nb_rounds):
             ms += quiz_machine.models_successes(models, c_quizzes)
-            # print(f"{r=} {ms=}")
-            i = ((ms == r + 1).long().sum(dim=1) == ms.size(1) - 1) & (
-                (ms == 0).long().sum(dim=1) == 1
+            nb_sure_and_correct = (ms == r + 1).long().sum(dim=1)
+            nb_sure_and_fail = (ms == 0).long().sum(dim=1)
+            to_keep = (
+                (nb_sure_and_correct + nb_sure_and_fail == ms.size(1))
+                & (nb_sure_and_fail >= 1)
+                & (nb_sure_and_fail <= args.max_fail_to_validate)
             )
-            c_quizzes = c_quizzes[i]
-            ms = ms[i]
+
+            c_quizzes = c_quizzes[to_keep]
+            ms = ms[to_keep]
+            print(f"Round {r} remains {c_quizzes.size(0)}")
             if c_quizzes.size(0) == 0:
                 break
 
         if c_quizzes.size(0) > 0:
             nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
-            total_nb_validated = nb_validated_per_model.sum().item()
             recorded_validated.append(c_quizzes)
 
+        total_nb_validated = nb_validated_per_model.sum().item()
+
         duration = time.perf_counter() - start_time
 
         if total_nb_validated > 0:
@@ -492,7 +514,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         )
 
     validated_quizzes = torch.cat(recorded_validated, dim=0)
-    # too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
 
     ######################################################################
     # store the new c_quizzes which have been validated
@@ -516,14 +537,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             args.result_dir, prefix, vq, show_part_to_predict=False
         )
 
-    # vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]]
-
-    # if vq.size(0) > 0:
-    # prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
-    # quiz_machine.save_quiz_illustrations(
-    # args.result_dir, prefix, vq, show_part_to_predict=False
-    # )
-
 
 ######################################################################
 
@@ -696,6 +709,19 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         )
         log_string(f"wrote {filename}")
 
+    for model in weakest_models:
+        c_quizzes = quiz_machine.generate_c_quizzes(
+            128,
+            model_for_generation=model,
+            p2a_only=args.p2a_only,
+            temperature_hot=args.temperature_hot,
+            temperature_cold=args.temperature_cold,
+        )
+
+        quiz_machine.save_quiz_illustrations(
+            args.result_dir, f"non_validated_{n_epoch:04d}_{model.id:02d}", c_quizzes
+        )
+
     # Renew the training samples
 
     for model in weakest_models:
index d0fda7e..51c0862 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -295,6 +295,23 @@ class MyGPT(nn.Module):
         bs = self.readout(bs)
         return bs
 
+    def partial_forward(self, bs, start_layer=None, end_layer=None):
+        if start_layer is None:
+            # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+            bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+            bs = self.embedding(bs)
+            if end_layer is not None:
+                return self.trunk[:end_layer](bs)
+            else:
+                bs = self.trunk(bs)
+                bs = self.readout(bs)
+                return bs
+        else:
+            bs = self.trunk[start_layer:](bs)
+            bs = self.trunk(bs)
+            bs = self.readout(bs)
+            return bs
+
     def record_attention(self, v=True):
         for m in self.modules():
             if isinstance(m, QKVAttention):
index a5f9a89..182e9ff 100755 (executable)
@@ -601,3 +601,217 @@ class QuizMachine:
             )
 
         return c_quizzes.to("cpu")
+
+    ######################################################################
+
+    def generate_c_quizzes_fixed_point(
+        self,
+        nb,
+        model_for_generation,
+        p2a_only=False,
+        temperature_hot=1.0,
+        temperature_cold=1.0,
+    ):
+        c_quizzes = torch.empty(
+            nb,
+            self.prompt_len + self.answer_len,
+            device=self.device,
+            dtype=torch.int64,
+        )
+
+        seq_logproba = torch.zeros(nb, device=self.device)
+
+        lt_noisy = lambda s, logits: logits / temperature_hot
+        lt_clean = lambda s, logits: logits / temperature_cold
+
+        c_quizzes[...] = self.problem.token_backward
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_before", c_quizzes)
+
+        c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+        while True:
+            print("ITERATION")
+
+            c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+            pred = c_quizzes.clone()
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+                seq_logproba=seq_logproba,
+                logit_transformer=lt_clean,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+                seq_logproba=seq_logproba,
+                logit_transformer=lt_clean,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            if pred[202:].equal(c_quizzes[202:]):
+                break
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_after", c_quizzes)
+
+        exit(0)
+
+        return c_quizzes.to("cpu")
+
+    ######################################################################
+
+    def generate_c_quizzes_mixing(
+        self,
+        nb,
+        model_for_generation,
+        p2a_only=False,
+        temperature_hot=1.0,
+        temperature_cold=1.0,
+    ):
+        c_quizzes = torch.empty(
+            nb,
+            self.prompt_len + self.answer_len,
+            device=self.device,
+            dtype=torch.int64,
+        )
+
+        c_quizzes_1 = torch.empty(
+            nb,
+            self.prompt_len + self.answer_len,
+            device=self.device,
+            dtype=torch.int64,
+        )
+
+        c_quizzes_2 = torch.empty(
+            nb,
+            self.prompt_len + self.answer_len,
+            device=self.device,
+            dtype=torch.int64,
+        )
+
+        seq_logproba = torch.zeros(nb, device=self.device)
+
+        lt_noisy = lambda s, logits: logits / temperature_hot
+        lt_clean = lambda s, logits: logits / temperature_cold
+
+        ######################################################################
+
+        c_quizzes_1[...] = self.problem.token_backward
+        ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes_1,
+            ar_mask=ar_mask,
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
+
+        c_quizzes_2[...] = self.problem.token_backward
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes_2,
+            ar_mask=ar_mask,
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_noisy,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
+
+        h = len(model_for_generation.trunk) // 2
+
+        with torch.autograd.no_grad():
+            t = model_for_generation.training
+            model_for_generation.eval()
+
+            bs1 = model_for_generation.partial_forward(
+                mygpt.BracketedSequence(c_quizzes_1), end_layer=h
+            )
+            bs2 = model_for_generation.partial_forward(
+                mygpt.BracketedSequence(c_quizzes_2), end_layer=h
+            )
+
+            alpha = 0.1
+
+            output = model_for_generation.partial_forward(
+                mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
+                start_layer=h,
+            ).x
+
+            dist = torch.distributions.categorical.Categorical(logits=output)
+            c_quizzes[...] = dist.sample()
+
+            c_quizzes[...] = (
+                ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
+            )
+
+            model_for_generation.train(t)
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
+
+        ######################################################################
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_clean,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
+
+        c_quizzes = self.problem.p_a_flip(c_quizzes)
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
+            seq_logproba=seq_logproba,
+            logit_transformer=lt_clean,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
+
+        print("DONE")
+        exit(0)
+
+        return c_quizzes.to("cpu")