Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 17:53:11 +0000 (19:53 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 17:53:11 +0000 (19:53 +0200)
grids.py
main.py
quiz_machine.py

index a115f93..e4d831c 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -1069,6 +1069,68 @@ class Grids(problem.Problem):
             X[i, j] = c[1]
             f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
 
+    # @torch.compile
+    def task_stack(self, A, f_A, B, f_B):
+        N = 5
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            i1, j1, i2, j2 = (
+                self.height // 2 - 1,
+                self.width // 2 - 1,
+                self.height // 2 + 1,
+                self.width // 2 + 1,
+            )
+            op = torch.tensor((0, 1, 2, 3) * 4)
+            op = op[torch.randperm(op.size(0))[:9]]
+            for q in range(op.size(0)):
+                u = 3 * (q // 3)
+                v = 3 * (q % 3)
+                d = c[torch.randint(N, (1,)).item()]
+                # X[u+1,v+1]=d
+                if op[q] == 0:  # right
+                    X[u : u + 3, v + 2] = d
+                elif op[q] == 1:  # let
+                    X[u : u + 3, v] = d
+                elif op[q] == 2:  # bottom
+                    X[u + 2, v : v + 3] = d
+                elif op[q] == 3:  # top
+                    X[u, v : v + 3] = d
+
+                if q == 0:
+                    f_X[i1:i2, j1:j2] = d
+                elif op[q] == 0:  # right
+                    f_X[i1:i2, j2] = d
+                    j2 += 1
+                elif op[q] == 1:  # let
+                    j1 -= 1
+                    f_X[i1:i2, j1] = d
+                elif op[q] == 2:  # bottom
+                    f_X[i2, j1:j2] = d
+                    i2 += 1
+                elif op[q] == 3:  # top
+                    i1 -= 1
+                    f_X[i1, j1:j2] = d
+
+    def randint(self, *m):
+        m = torch.tensor(m)
+        return (torch.rand(m.size()) * m).long()
+
+    def task_matrices(self, A, f_A, B, f_B):
+        N = 6
+        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            M1 = torch.randint(2, (5, 5))
+            M2 = torch.randint(2, (5, 5))
+            P = M1 @ M2
+            for i in range(5):
+                for j in range(5):
+                    X[i, j] = c[M1[i, j]]
+                    X[i, j + 5] = c[M2[i, j]]
+                    f_X[i, j] = c[M1[i, j]]
+                    f_X[i, j + 5] = c[M2[i, j]]
+                    f_X[i + 5, j + 5] = c[P[i, j]]
+
     ######################################################################
 
     def trivial_prompts_and_answers(self, prompts, answers):
@@ -1159,7 +1221,7 @@ if __name__ == "__main__":
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_count]:
+    for t in [grids.task_matrices]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         grids.save_quiz_illustrations(
diff --git a/main.py b/main.py
index b149e62..957e95a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -88,7 +88,7 @@ parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--generation_temperature", type=float, default=2)
+parser.add_argument("--generation_temperature", type=float, default=2)
 
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
@@ -410,25 +410,31 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     start_time = time.perf_counter()
 
+    nb_validated = torch.zeros(len(models))
+
     while nb_validated < nb_to_create:
-        model_for_generation = models[torch.randint(len(models), (1,))]
+        # We balance the number of quizzes per model
+
+        model_for_generation = models[nb_validated.argmin()]
 
         c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
             model_for_generation=model_for_generation,
             forward_only=args.forward_only,
+            generation_temperature=args.generation_temperature,
         )
 
         c_quizzes = keep_good_quizzes(models, c_quizzes)
 
-        nb_validated += c_quizzes.size(0)
+        nb_validated[model.id] += c_quizzes.size(0)
+        total_nb_validated = nb_validated.sum()
 
         recorded.append(c_quizzes)
 
         duration = time.perf_counter() - start_time
 
-        if nb_validated > 0 and nb_validated < nb_to_create:
-            d = (nb_to_create - nb_validated) * duration / nb_validated
+        if total_nb_validated > 0 and total_nb_validated < nb_to_create:
+            d = (nb_to_create - total_nb_validated) * duration / total_nb_validated
             e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
                 "%a %H:%M"
             )
@@ -436,7 +442,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             e = "???"
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e})"
         )
 
     validated_quizzes = torch.cat(recorded, dim=0)
index 008e435..0b84b36 100755 (executable)
@@ -541,7 +541,13 @@ class QuizMachine:
 
     ###############################################################
 
-    def generate_c_quizzes(self, nb, model_for_generation, forward_only=False):
+    def generate_c_quizzes(
+        self,
+        nb,
+        model_for_generation,
+        forward_only=False,
+        generation_temperature=1.0
+    ):
         c_quizzes = torch.empty(
             nb,
             self.prompt_len + self.answer_len + 2,
@@ -561,7 +567,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes, first=True),
                 seq_logproba=seq_logproba,
-                temperature=1.0,
+                temperature=generation_temperature,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -572,7 +578,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes),
                 seq_logproba=seq_logproba,
-                temperature=1,
+                temperature=1.0
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -587,7 +593,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_ar_mask(c_quizzes, first=True),
                 seq_logproba=seq_logproba,
-                temperature=1.0,
+                temperature=generation_temperature,
                 deterministic_synthesis=False,
                 device=self.device,
             )