Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 22 Jul 2024 04:36:56 +0000 (06:36 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 22 Jul 2024 04:36:56 +0000 (06:36 +0200)
grids.py
main.py
quiz_machine.py

index d3e7dcc..22704b2 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -716,9 +716,9 @@ class Grids(problem.Problem):
         while True:
             error = False
 
-            N = torch.randint(5, (1,)).item() + 2
-            c = torch.zeros(N + 1, dtype=torch.int64)
-            c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+            N = 3
+            c = torch.zeros(N + 2, dtype=torch.int64)
+            c[1:] = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
 
             for X, f_X in [(A, f_A), (B, f_B)]:
                 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
@@ -728,7 +728,7 @@ class Grids(problem.Problem):
                             self.height,
                             self.width,
                             nb_seeds=self.height * self.width // 8,
-                            nb_iterations=self.height * self.width // 10,
+                            nb_iterations=self.height * self.width // 5,
                         )
                     )
 
@@ -739,20 +739,20 @@ class Grids(problem.Problem):
                 # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
                 # c.size(0) - 1
                 # ) + 1
-                V = torch.randint(c.size(0) - 1, (X.max() + 1,)) + 1
+
+                V = torch.randint(N, (X.max() + 1,)) + 1
                 V[0] = 0
                 NB = F.one_hot(c[V]).sum(dim=0)
                 X[...] = c[V[X]]
-
-                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
-                    f_X[...] = 0
-                    for e in range(1, N + 1):
-                        for j in range(NB[c[e]]):
-                            if j < self.width:
-                                f_X[e - 1, j] = c[e]
-                            else:
-                                error = True
-                                break
+                f_X[...] = X
+
+                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
+                    m = NB[c[:-1]].max()
+                    if (NB[c[:-1]] == m).long().sum() == 1:
+                        for e in range(1, N + 1):
+                            if NB[c[e]] == m:
+                                a = (f_X == c[e]).long()
+                                f_X[...] = (1 - a) * f_X + a * c[-1]
                 else:
                     error = True
                     break
@@ -1423,7 +1423,7 @@ if __name__ == "__main__":
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_contact]:
+    for t in [grids.task_count]:
         # for t in [grids.task_symbols]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
diff --git a/main.py b/main.py
index b7d0431..f8f8502 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -97,6 +97,8 @@ parser.add_argument("--temperature_hot", type=float, default=1.5)
 
 parser.add_argument("--temperature_cold", type=float, default=0.75)
 
+parser.add_argument("--nb_rounds", type=int, default=3)
+
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
 parser.add_argument("--p2a_only", action="store_true", default=False)
@@ -413,57 +415,27 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 ######################################################################
 
 
-def keep_good_quizzes(models, quizzes, required_nb_failures=1):
-    quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
-
-    if args.c_quiz_validation_mode == "proba":
-        token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
-        l = token_logprobas.sum(dim=-1).sort(dim=-1).values
-
-        to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
-            l[:, 1] > math.log(args.proba_understands)
-        )
-
-    elif args.c_quiz_validation_mode == "predict":
-        nc = quiz_machine.solution_nb_correct(models, quizzes)
-
-        count_nc = tuple(
-            n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
-        )
-
-        log_string(f"nb_correct {count_nc}")
-
-        to_keep = nc == (len(models) - required_nb_failures)
-
-    else:
-        raise ValueError(f"{args.c_quiz_validation_mode=}")
-
-    if args.dirty_debug:
-        # warnings.warn("DEBUG", RuntimeWarning)
-        to_keep = torch.rand(to_keep.size(), device=to_keep.device) < 0.5
-
-    return quizzes[to_keep]
-
-
-######################################################################
-
-
 def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
-    nb_to_create = nb_for_train + nb_for_test
-    nb_to_generate_per_iteration = nb_to_create
+    nb_to_validate = nb_for_train + nb_for_test
+    nb_to_generate_per_iteration = nb_to_validate
     nb_validated = 0
 
     recorded_validated = []
-    recorded_too_simple = []
+    recorded_too_simple = []
 
     start_time = time.perf_counter()
 
-    nb_validated = torch.zeros(len(models), dtype=torch.int64)
+    nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
 
-    while nb_validated.sum() < nb_to_create:
+    while nb_validated_per_model.sum() < nb_to_validate:
         # We balance the number of quizzes per model
 
-        model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0]
+        model_for_generation = sorted(
+            models, key=lambda m: nb_validated_per_model[m.id]
+        )[0]
+
+        # We generate quizzes with a procedure that injects some
+        # structured noise
 
         c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
@@ -473,30 +445,40 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             temperature_cold=args.temperature_cold,
         )
 
-        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
-
-        nc = quiz_machine.solution_nb_correct(models, c_quizzes)
-
-        count_nc = tuple(
-            n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
-        )
-
-        log_string(f"nb_correct {count_nc}")
+        # We discard the trivial ones
 
-        recorded_too_simple.append(c_quizzes[nc == len(models)])
+        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
-        c_quizzes = c_quizzes[nc == len(models) - 1]
+        # We go through nb_rounds rounds and keep only quizzes on
+        # which models respond always the same through rounds
 
-        nb_validated[model_for_generation.id] += c_quizzes.size(0)
-        total_nb_validated = nb_validated.sum().item()
+        total_nb_validated = 0
+        ms = 0
+        for r in range(args.nb_rounds):
+            ms += quiz_machine.models_successes(models, c_quizzes)
+            # print(f"{r=} {ms=}")
+            i = ((ms == r + 1).long().sum(dim=1) == ms.size(1) - 1) & (
+                (ms == 0).long().sum(dim=1) == 1
+            )
+            c_quizzes = c_quizzes[i]
+            ms = ms[i]
+            if c_quizzes.size(0) == 0:
+                break
 
-        recorded_validated.append(c_quizzes)
+        if c_quizzes.size(0) > 0:
+            nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
+            total_nb_validated = nb_validated_per_model.sum().item()
+            recorded_validated.append(c_quizzes)
 
         duration = time.perf_counter() - start_time
 
         if total_nb_validated > 0:
-            if total_nb_validated < nb_to_create:
-                d = (nb_to_create - total_nb_validated) * duration / total_nb_validated
+            if total_nb_validated < nb_to_validate:
+                d = (
+                    (nb_to_validate - total_nb_validated)
+                    * duration
+                    / total_nb_validated
+                )
                 e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
                     "%a %H:%M"
                 )
@@ -506,11 +488,11 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             e = "???"
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
         )
 
     validated_quizzes = torch.cat(recorded_validated, dim=0)
-    too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
+    too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
 
     ######################################################################
     # store the new c_quizzes which have been validated
@@ -519,7 +501,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     quiz_machine.store_c_quizzes(v_train, for_train=True)
     quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True)
 
-    v_test = validated_quizzes[nb_for_train:nb_to_create]
+    v_test = validated_quizzes[nb_for_train:nb_to_validate]
     quiz_machine.store_c_quizzes(v_test, for_train=False)
     quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False)
 
@@ -534,13 +516,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             args.result_dir, prefix, vq, show_part_to_predict=False
         )
 
-    vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]]
+    vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]]
 
-    if vq.size(0) > 0:
-        prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
-        quiz_machine.save_quiz_illustrations(
-            args.result_dir, prefix, vq, show_part_to_predict=False
-        )
+    if vq.size(0) > 0:
+    # prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
+    # quiz_machine.save_quiz_illustrations(
+    # args.result_dir, prefix, vq, show_part_to_predict=False
+    # )
 
 
 ######################################################################
index b965e33..5f14528 100755 (executable)
@@ -59,8 +59,8 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature,
-    deterministic_synthesis,
+    logit_transformer=None,
+    deterministic_synthesis=False,
 ):
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
@@ -73,7 +73,7 @@ def one_batch_masked_inplace_autoregression(
 
         logits = output[:, s]
 
-        logits = (logits / temperature).log_softmax(dim=-1)
+        logits = logit_transformer(s, logits).log_softmax(dim=-1)
 
         if deterministic_synthesis:
             t_next = logits.argmax(-1)
@@ -94,8 +94,8 @@ def masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature,
-    deterministic_synthesis,
+    logit_transformer=None,
+    deterministic_synthesis=False,
     forbidden_tokens=None,
     logit_biases=None,
     progress_bar_desc=None,
@@ -127,7 +127,7 @@ def masked_inplace_autoregression(
                 input=input,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba,
-                temperature=temperature,
+                logit_transformer=logit_transformer,
                 deterministic_synthesis=deterministic_synthesis,
             )
 
@@ -305,7 +305,6 @@ class QuizMachine:
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba,
-                temperature=1.0,
                 deterministic_synthesis=deterministic_synthesis,
                 progress_bar_desc="accuracy",
                 device=self.device,
@@ -447,15 +446,14 @@ class QuizMachine:
 
     ###############################################################
 
-    def solution_nb_correct(self, models_for_validation, c_quizzes):
+    def models_successes(self, models_for_validation, c_quizzes):
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
             device=self.device,
         )
 
-        nb_correct = 0
-        correct_models = torch.empty(
+        correctly_solved = torch.empty(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
             device=self.device,
@@ -477,14 +475,11 @@ class QuizMachine:
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
-                temperature=1.0,
                 deterministic_synthesis=False,
                 device=self.device,
             )
 
-            correct_models[:, model.id] = (
-                (c_quizzes == result).long().min(dim=-1).values
-            )
+            correct = (c_quizzes == result).long().min(dim=-1).values
 
             # -------------------------------
 
@@ -502,22 +497,17 @@ class QuizMachine:
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
-                temperature=1.0,
                 deterministic_synthesis=False,
                 device=self.device,
             )
 
-            correct_models[:, model.id] *= (
-                (c_quizzes == result).long().min(dim=-1).values
-            )
+            correct *= (c_quizzes == result).long().min(dim=-1).values
 
             # -------------------------------
 
-        i = correct_models.sum(dim=1) == correct_models.size(1) - 1
-        c = (correct_models[i] == 0).long().sum(dim=0)
-        self.logger(f"nb_failures_on_validated {tuple(x.item() for x in c)}")
+            correctly_solved[:, model.id] = correct
 
-        return correct_models.sum(dim=1).to("cpu")
+        return correctly_solved.to("cpu")
 
     ###############################################################
 
@@ -538,6 +528,9 @@ class QuizMachine:
 
         seq_logproba = torch.zeros(nb, device=self.device)
 
+        def heater(T):
+            return lambda s, logits: logits / T
+
         if p2a_only:
             c_quizzes[...] = self.problem.token_forward
 
@@ -547,7 +540,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_hot,
+                logit_transformer=heater(temperature_hot),
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -558,7 +551,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_cold,
+                logit_transformer=heater(temperature_cold),
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -572,7 +565,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_hot,
+                logit_transformer=heater(temperature_hot),
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -583,7 +576,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_cold,
+                logit_transformer=heater(temperature_cold),
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -596,7 +589,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_cold,
+                logit_transformer=heater(temperature_cold),
                 deterministic_synthesis=False,
                 device=self.device,
             )