Update.
author François Fleuret <francois@fleuret.org>
Mon, 15 Jul 2024 13:28:11 +0000 (15:28 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 15 Jul 2024 13:28:11 +0000 (15:28 +0200)
grids.py
main.py
quiz_machine.py

diff --git a/grids.py b/grids.py
index eea8c6c..7752136 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -143,7 +143,7 @@ class Grids(problem.Problem):
             self.task_scale,
             self.task_symbols,
             self.task_isometry,
-            #            self.task_islands,
+            self.task_islands,
         ]
 
         if tasks is None:
@@ -617,8 +617,8 @@ class Grids(problem.Problem):
         while True:
             error = False
 
-            N = torch.randint(5, (1,)).item() + 1
-            c = torch.zeros(N + 1)
+            N = torch.randint(5, (1,)).item() + 2
+            c = torch.zeros(N + 1, dtype=torch.int64)
             c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
 
             for X, f_X in [(A, f_A), (B, f_B)]:
@@ -635,18 +635,20 @@ class Grids(problem.Problem):
 
                 X[...] = self.cache_count.pop()
 
-                k = (X.max() + 1 + (c.size(0) - 1)).item()
-                V = torch.arange(k) // (c.size(0) - 1)
-                V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
-                    c.size(0) - 1
-                ) + 1
+                # k = (X.max() + 1 + (c.size(0) - 1)).item()
+                # V = torch.arange(k) // (c.size(0) - 1)
+                # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
+                # c.size(0) - 1
+                # ) + 1
+                V = torch.randint(c.size(0) - 1, (X.max() + 1,)) + 1
                 V[0] = 0
+                NB = F.one_hot(c[V]).sum(dim=0)
                 X[...] = c[V[X]]
 
                 if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
                     f_X[...] = 0
                     for e in range(1, N + 1):
-                        for j in range((X == c[e]).sum() + 1):
+                        for j in range(NB[c[e]]):
                             if j < self.width:
                                 f_X[e - 1, j] = c[e]
                             else:
@@ -659,6 +661,8 @@ class Grids(problem.Problem):
             if not error:
                 break
 
+        assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3
+
     # @torch.compile
     def task_trajectory(self, A, f_A, B, f_B):
         c = torch.randperm(len(self.colors) - 1)[:2] + 1
@@ -1155,19 +1159,19 @@ if __name__ == "__main__":
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_distance]:
+    for t in [grids.task_count]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         grids.save_quiz_illustrations(
             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
     # for t in grids.all_tasks:
-    for t in [grids.task_distance]:
+    for t in [grids.task_count]:
         start_time = time.perf_counter()
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
diff --git a/main.py b/main.py
index cdaacdf..07fec96 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -441,32 +441,34 @@ def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
 
         duration = time.perf_counter() - start_time
 
-        if nb_validated > 0:
-            e = (nb_to_create - nb_validated) * duration / nb_validated
-            if e > 0:
-                e = "~" + str(datetime.timedelta(seconds=int(e)))
-            else:
-                e = "0s"
+        if nb_validated > 0 and nb_validated < nb_to_create:
+            d = (nb_to_create - nb_validated) * duration / nb_validated
         else:
-            e = "???"
+            d = 0
+
+        e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
+            "%a %H:%M"
+        )
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (remaining time {e})"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finish {e})"
         )
 
     # store the new c_quizzes which have been validated
 
-    quiz_machine.reverse_random_half_in_place(validated_quizzes)
-    quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
-    quiz_machine.store_c_quizzes(
-        validated_quizzes[nb_for_train:nb_to_create], for_train=False
-    )
+    v_train = validated_quizzes[:nb_for_train]
+    quiz_machine.store_c_quizzes(v_train, for_train=True)
+    quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_train), for_train=True)
+
+    v_test = validated_quizzes[nb_for_train:nb_to_create]
+    quiz_machine.store_c_quizzes(v_test, for_train=False)
+    quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_test), for_train=False)
 
     ######################################################################
     # save images with their logprobas
 
-    vq = validated_quizzes[:72]
-    vl = validated_logprobas[:72]
+    vq = validated_quizzes[:128]
+    vl = validated_logprobas[:128]
 
     if vq.size(0) > 0:
         prefix = f"culture_c_quiz_{n_epoch:04d}"
@@ -591,10 +593,10 @@ if args.max_percents_of_test_in_train >= 0:
 ######################################################################
 
 if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100
 
 if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 100
 
 log_string(
     f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
diff --git a/quiz_machine.py b/quiz_machine.py
index 927a349..bcb89ec 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -415,8 +415,8 @@ class QuizMachine:
         self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
-            quizzes=test_result[:72],
-            mistakes=test_correct[:72] * 2 - 1,
+            quizzes=test_result[:128],
+            mistakes=test_correct[:128] * 2 - 1,
         )
 
         return main_test_accuracy