Add a task_symmetry task to grids.py; add an "alien" symmetry-only problem / quiz machine and a prefix argument to run_ae_test in main.py.
author     François Fleuret <francois@fleuret.org>
           Thu, 5 Sep 2024 15:28:25 +0000 (17:28 +0200)
committer  François Fleuret <francois@fleuret.org>
           Thu, 5 Sep 2024 15:28:25 +0000 (17:28 +0200)
grids.py
main.py

diff --git a/grids.py b/grids.py
index 2717b22..9e80f62 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -284,11 +284,14 @@ class Grids(problem.Problem):
         self.cache_rec_coo = {}
 
         all_tasks = [
+            ############################################ fundamental ones
             self.task_replace_color,
             self.task_translate,
             self.task_grow,
-            self.task_half_fill,
             self.task_frame,
+            ############################################
+            ############################################
+            self.task_half_fill,
             self.task_detect,
             self.task_scale,
             self.task_symbols,
@@ -700,6 +703,27 @@ class Grids(problem.Problem):
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
+    # @torch.compile
+    def task_symmetry(self, A, f_A, B, f_B):
+        a, b = torch.randint(2, (2,))
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
+            X[: self.height // 2] = c[-1]
+            f_X[: self.height // 2] = f_X.flip([0])[: self.height // 2]
+            if a == 1:
+                X[...] = X.clone().t()
+                f_X[...] = f_X.clone().t()
+            if b == 1:
+                Z = X.clone()
+                X[...] = f_X
+                f_X[...] = Z
+
     # @torch.compile
     def task_translate(self, A, f_A, B, f_B):
         while True:
@@ -1812,7 +1836,7 @@ if __name__ == "__main__":
 
     # for t in grids.all_tasks:
 
-    for t in [grids.task_recworld_immobile]:
+    for t in [grids.task_symmetry]:
         print(t.__name__)
         w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         grids.save_quizzes_as_image(
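
The task_symmetry task added above paints the same rectangles in the prompt and the answer, masks the top half of the prompt with a dedicated color, and makes the answer the mirror completion of its own bottom half; a random transpose turns the horizontal axis of symmetry into a vertical one, and a random swap exchanges the roles of prompt and answer. The following standalone sketch reproduces that construction outside the Grids class; the 10x10 grid, the 12-color palette, and the use of unconstrained random rectangles instead of rec_coo(prevent_overlap=True) are simplifying assumptions.

# Standalone sketch of the symmetry construction in task_symmetry.
# Assumptions: 10x10 grid, 12 colors, plain random rectangles instead of
# the class's rec_coo(prevent_overlap=True).
import torch

height, width, nb_colors, nb_rec = 10, 10, 12, 3

a, b = torch.randint(2, (2,))                        # transpose flag, swap flag
c = torch.randperm(nb_colors - 1)[: nb_rec + 1] + 1  # rectangle colors + mask color

X = torch.zeros(height, width, dtype=torch.int64)    # prompt
f_X = torch.zeros(height, width, dtype=torch.int64)  # answer

# Paint the same random rectangles in the prompt and the answer
for n in range(nb_rec):
    i1, i2 = sorted(torch.randint(height, (2,)).tolist())
    j1, j2 = sorted(torch.randint(width, (2,)).tolist())
    X[i1 : i2 + 1, j1 : j2 + 1] = c[n]
    f_X[i1 : i2 + 1, j1 : j2 + 1] = c[n]

# Mask the top half of the prompt; the answer becomes the vertical-mirror
# completion of its own bottom half
X[: height // 2] = c[-1]
f_X[: height // 2] = f_X.flip([0])[: height // 2]

if a == 1:  # optionally turn the horizontal symmetry axis into a vertical one
    X, f_X = X.t().contiguous(), f_X.t().contiguous()

if b == 1:  # optionally swap the roles of prompt and answer
    X, f_X = f_X, X

In the class version the writes go through X[...] and f_X[...] so that A, f_A, B, f_B keep the caller's storage, which is why clone() is taken before transposing or swapping in place.
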
diff --git a/main.py b/main.py
index 174b9b8..8e938db 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -284,6 +284,24 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
+# ------------------------------------------------------
+alien_problem = grids.Grids(
+    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+    chunk_size=100,
+    nb_threads=args.nb_threads,
+    tasks="symmetry",
+)
+
+alien_quiz_machine = quiz_machine.QuizMachine(
+    problem=alien_problem,
+    batch_size=args.inference_batch_size,
+    result_dir=args.result_dir,
+    logger=log_string,
+    device=main_device,
+)
+
+# ------------------------------------------------------
+
 ######################################################################
 
 problem = grids.Grids(
@@ -918,7 +936,15 @@ def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0):
     return targets, logits
 
 
-def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device):
+######################################################################
+
+
+def run_ae_test(
+    model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
+):
+    # An empty prefix keeps the historical log and image file names
+    prefix = "" if prefix is None else prefix + "_"
+
     with torch.autograd.no_grad():
         model.eval().to(local_device)
 
@@ -940,7 +966,7 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_
             nb_test_samples += input.size(0)
 
         log_string(
-            f"test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
+            f"{prefix}test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
         )
 
         # Compute the accuracy and save some images
@@ -975,15 +1001,16 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_
             record_nd.append((result[nd], predicted_parts[nd], correct_parts[nd]))
 
         log_string(
-            f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+            f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
         )
 
-        model.test_accuracy = nb_correct / nb_total
+        if prefix == "":
+            model.test_accuracy = nb_correct / nb_total
 
         # Save some images
 
         for f, record in [("prediction", record_d), ("generation", record_nd)]:
-            filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+            filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
 
             result, predicted_parts, correct_parts = bag_to_tensors(record)
 
@@ -1366,6 +1393,17 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     # --------------------------------------------------------------------
 
+    # run_ae_test(
+    # model,
+    # alien_quiz_machine,
+    # n_epoch,
+    # c_quizzes=None,
+    # local_device=main_device,
+    # prefix="alien",
+    # )
+
+    # exit(0)
+
     # one_ae_epoch(models[0], quiz_machine, n_epoch, main_device)
     # exit(0)