Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 25 Jul 2024 16:15:40 +0000 (18:15 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 25 Jul 2024 16:15:40 +0000 (18:15 +0200)
grids.py
main.py
quiz_machine.py

index 1d94e07..67a5c97 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -233,8 +233,7 @@ class Grids(problem.Problem):
         max_nb_cached_chunks=None,
         chunk_size=None,
         nb_threads=-1,
-        world_tasks=None,
-        science_tasks=None,
+        tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
 
@@ -264,7 +263,7 @@ class Grids(problem.Problem):
 
         self.cache_rec_coo = {}
 
-        self.all_tasks = [
+        all_tasks = [
             self.task_replace_color,
             self.task_translate,
             self.task_grow,
@@ -285,17 +284,10 @@ class Grids(problem.Problem):
             # self.task_islands, # TOO MESSY
         ]
 
-        if world_tasks is None:
-            self.world_tasks = self.all_tasks
+        if tasks is None:
+            self.all_tasks = all_tasks
         else:
-            self.world_tasks = [
-                getattr(self, "task_" + t) for t in world_tasks.split(",")
-            ]
-
-        if science_tasks is not None:
-            self.science_tasks = [
-                getattr(self, "task_" + t) for t in science_tasks.split(",")
-            ]
+            self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
 
         super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
 
@@ -1445,6 +1437,59 @@ class Grids(problem.Problem):
                 rs = (S // (2**j)) % 2
                 f_X[2, -j - 1] = c[2 + rs]
 
+    def task_science_implicit(self, A, f_A, B, f_B):
+        nb_rec = 5
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                i1, i2 = torch.randint(self.height, (2,)).sort().values
+                if i1 >= 1 and i2 < self.height and i1 + 3 < i2:
+                    break
+
+            while True:
+                j1, j2 = torch.randint(self.width, (2,)).sort().values
+                if j1 >= 1 and j2 < self.width and j1 + 3 < j2:
+                    break
+
+            f_X[i1:i2, j1:j2] = c[0]
+
+            # ---------------------
+
+            while True:
+                ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+                if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+                    break
+            jj = torch.randint(j1, (1,))
+            X[ii1:ii2, jj:j1] = c[1]
+            f_X[ii1:ii2, jj:j1] = c[1]
+
+            while True:
+                ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+                if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+                    break
+            jj = torch.randint(self.width - j2, (1,)) + j2 + 1
+            X[ii1:ii2, j2:jj] = c[2]
+            f_X[ii1:ii2, j2:jj] = c[2]
+
+            # ---------------------
+
+            while True:
+                jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+                if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+                    break
+            ii = torch.randint(i1, (1,))
+            X[ii:i1, jj1:jj2] = c[3]
+            f_X[ii:i1, jj1:jj2] = c[3]
+
+            while True:
+                jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+                if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+                    break
+            ii = torch.randint(self.height - i2, (1,)) + i2 + 1
+            X[i2:ii, jj1:jj2] = c[4]
+            f_X[i2:ii, jj1:jj2] = c[4]
+
     # end_tasks
 
     ######################################################################
@@ -1459,14 +1504,11 @@ class Grids(problem.Problem):
 
         return quizzes
 
-    def generate_w_quizzes_(self, nb, tasks=None, science=False, progress_bar=False):
+    def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
         S = self.height * self.width
 
         if tasks is None:
-            if science:
-                tasks = self.science_tasks
-            else:
-                tasks = self.world_tasks
+            tasks = self.all_tasks
 
         quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
 
@@ -1487,10 +1529,9 @@ class Grids(problem.Problem):
 
         return quizzes
 
-    def save_some_examples(self, result_dir, science=False):
+    def save_some_examples(self, result_dir):
         nb, nrow = 128, 4
-        tasks = self.science_tasks if science else self.world_tasks
-        for t in tasks:
+        for t in self.all_tasks:
             print(t.__name__)
             quizzes = self.generate_w_quizzes_(nb, tasks=[t])
             self.save_quizzes_as_image(
@@ -1545,9 +1586,9 @@ if __name__ == "__main__":
     nb, nrow = 128, 4
     # nb, nrow = 8, 2
 
-    # for t in grids.world_tasks:
+    # for t in grids.all_tasks:
 
-    for t in [grids.task_path]:
+    for t in [grids.task_science_implicit]:
         print(t.__name__)
         quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         grids.save_quizzes_as_image(
@@ -1557,7 +1598,7 @@ if __name__ == "__main__":
             comments=[f"{t.__name__} #{k}" for k in range(quizzes.size(0))],
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
diff --git a/main.py b/main.py
index b49fa06..4d618cc 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -125,14 +125,6 @@ parser.add_argument(
     help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
 )
 
-assert (
-    len(
-        set(args.grids_world_tasks.split(","))
-        & set(args.grids_science_tasks.split(","))
-    )
-    == 0
-), "World and science task have to be disjoint"
-
 ######################################################################
 
 parser.add_argument("--sky_height", type=int, default=6)
@@ -152,6 +144,14 @@ args = parser.parse_args()
 if args.result_dir is None:
     args.result_dir = f"results_culture"
 
+assert not args.grids_science_tasks or (
+    len(
+        set(args.grids_world_tasks.split(","))
+        & set(args.grids_science_tasks.split(","))
+    )
+    == 0
+), "World and science tasks have to be disjoint"
+
 ######################################################################
 
 default_args = {
@@ -304,17 +304,25 @@ if args.problem == "sky":
         chunk_size=100,
         nb_threads=args.nb_threads,
     )
-    back_accuracy = False
 
 elif args.problem == "grids":
     problem = grids.Grids(
         max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
         chunk_size=100,
         nb_threads=args.nb_threads,
-        world_tasks=args.grids_world_tasks,
-        science_tasks=args.grids_science_tasks,
+        tasks=args.grids_world_tasks,
     )
-    back_accuracy = True
+
+    if args.grids_science_tasks is None:
+        science_w_quizzes = None
+    else:
+        science_problem = grids.Grids(
+            max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+            chunk_size=100,
+            nb_threads=args.nb_threads,
+            tasks=args.grids_science_tasks,
+        )
+        science_w_quizzes = science_problem.generate_w_quizzes(args.nb_test_samples)
 
 else:
     raise ValueError
@@ -324,7 +332,6 @@ if not args.resume:
 
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
-    back_accuracy=back_accuracy,
     batch_size=args.physical_batch_size,
     result_dir=args.result_dir,
     logger=log_string,
@@ -600,12 +607,6 @@ for k in range(args.nb_gpts):
 
 ######################################################################
 
-science_w_quizzes = quiz_machine.problem.generate_w_quizzes(
-    args.nb_test_samples, science=True
-)
-
-######################################################################
-
 current_epoch = 0
 
 if args.resume:
@@ -757,6 +758,24 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             c_quizzes,
         )
 
+    ######################################################################
+
+    if science_w_quizzes is not None:
+        result, correct = quiz_machine.predict(
+            model=model,
+            quizzes=science_w_quizzes.to(main_device),
+            struct=("A", "f_A", "B", "f_B"),
+            mask=(0, 0, 0, 1),
+        )
+
+        nb_correct = (correct == 1).long().sum()
+        nb_total = (correct != 0).long().sum()
+        log_string(
+            f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
+        )
+
+    ######################################################################
+
     # Renew the training samples
 
     for model in weakest_models:
index 4048b39..8e40921 100755 (executable)
@@ -115,7 +115,6 @@ class QuizMachine:
     def __init__(
         self,
         problem,
-        back_accuracy,
         batch_size,
         result_dir,
         logger,
@@ -124,7 +123,6 @@ class QuizMachine:
         super().__init__()
 
         self.problem = problem
-        self.back_accuracy = back_accuracy
         self.batch_size = batch_size
         self.device = device
         self.logger = logger