Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 21:55:57 +0000 (23:55 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 21:55:57 +0000 (23:55 +0200)
grids.py
main.py
mygpt.py
quiz_machine.py

index d41ec49..296c23a 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -722,7 +722,7 @@ class Grids(problem.Problem):
 
         return no, nq, nq_diag
 
-    def task_count(self, A, f_A, B, f_B):
+    def REMOVED_task_count(self, A, f_A, B, f_B):
         while True:
             error = False
 
@@ -1022,7 +1022,7 @@ class Grids(problem.Problem):
                 return dist * (1 - walls)
 
     # @torch.compile
-    def task_distance(self, A, f_A, B, f_B):
+    def REMOVED_task_distance(self, A, f_A, B, f_B):
         c = torch.randperm(len(self.colors) - 1)[:3] + 1
         dist0 = torch.empty(self.height + 2, self.width + 2)
         dist1 = torch.empty(self.height + 2, self.width + 2)
@@ -1085,7 +1085,7 @@ class Grids(problem.Problem):
     # if
 
     # @torch.compile
-    def task_puzzle(self, A, f_A, B, f_B):
+    def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
         S = 4
         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
         c = torch.randperm(len(self.colors) - 1)[:4] + 1
@@ -1153,7 +1153,7 @@ class Grids(problem.Problem):
                         if f_X[i + i0, j + j0] == c[d]:
                             X[ii + i, jj + j] = c[d]
 
-    def task_islands(self, A, f_A, B, f_B):
+    def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
         c = torch.randperm(len(self.colors) - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
@@ -1183,7 +1183,7 @@ class Grids(problem.Problem):
             X[i, j] = c[1]
 
     # @torch.compile
-    def task_stack(self, A, f_A, B, f_B):
+    def TOO_HARD_task_stack(self, A, f_A, B, f_B):
         N = 5
         c = torch.randperm(len(self.colors) - 1)[:N] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
@@ -1228,7 +1228,7 @@ class Grids(problem.Problem):
         m = torch.tensor(m)
         return (torch.rand(m.size()) * m).long()
 
-    def task_matrices(self, A, f_A, B, f_B):
+    def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
         N = 6
         c = torch.randperm(len(self.colors) - 1)[:N] + 1
 
@@ -1244,7 +1244,7 @@ class Grids(problem.Problem):
                     f_X[i, j + 5] = c[M2[i, j]]
                     f_X[i + 5, j + 5] = c[P[i, j]]
 
-    def task_compute(self, A, f_A, B, f_B):
+    def TOO_HARD_task_compute(self, A, f_A, B, f_B):
         N = 6
         c = torch.randperm(len(self.colors) - 1)[:N] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
@@ -1423,7 +1423,7 @@ class Grids(problem.Problem):
                 if accept_full or (d * (X == 0)).max() == self.height * self.width:
                     break
 
-    def task_addition(self, A, f_A, B, f_B):
+    def TOO_HARD_task_addition(self, A, f_A, B, f_B):
         c = torch.randperm(len(self.colors) - 1)[:4] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
@@ -1654,24 +1654,46 @@ if __name__ == "__main__":
 
     for t in [grids.task_science_tag]:
         print(t.__name__)
-        quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+        w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         grids.save_quizzes_as_image(
             "/tmp",
             t.__name__ + ".png",
-            quizzes,
-            comments=[f"{t.__name__} #{k}" for k in range(quizzes.size(0))],
+            w_quizzes,
+            comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
-    # for t in grids.all_tasks:
-    for t in [grids.task_path]:
+    for t in [
+        grids.task_addition,
+        grids.task_bounce,
+        grids.task_compute,
+        grids.task_contact,
+        grids.task_corners,
+        grids.task_detect,
+        grids.task_fill,
+        grids.task_frame,
+        grids.task_grow,
+        grids.task_half_fill,
+        grids.task_islands,
+        grids.task_isometry,
+        grids.task_path,
+        grids.task_puzzle,
+        grids.task_replace_color,
+        grids.task_scale,
+        grids.task_stack,
+        grids.task_symbols,
+        grids.task_trajectory,
+        grids.task_translate,
+    ]:
+        # for t in [grids.task_path]:
         start_time = time.perf_counter()
         w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
         print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s")
+        grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128])
 
     exit(0)
 
diff --git a/main.py b/main.py
index bf617b5..fbc6f42 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -105,6 +105,8 @@ parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
+parser.add_argument("--autoencoder_dim", type=int, default=-1)
+
 ######################################################################
 
 grids_tasks = ", ".join(
@@ -449,18 +451,25 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
 ######################################################################
 
+lt_noisy = lambda s, logits: logits / args.temperature_hot
+lt_clean = lambda s, logits: logits / args.temperature_cold
+
+c_quizzes_procedure_ = [
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), lt_noisy),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), lt_clean),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), lt_clean),
+]
+
+c_quizzes_procedure = [
+    (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), lt_noisy),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), lt_clean),
+]
+
 
 def save_additional_results(models, science_w_quizzes):
     for model in models:
         c_quizzes = quiz_machine.generate_c_quizzes(
-            128,
-            model_for_generation=model,
-            temperature_hot=args.temperature_hot,
-            temperature_cold=args.temperature_cold,
-        )
-
-        c_quizzes = quiz_machine.problem.reconfigure(
-            c_quizzes, ("A", "f_A", "B", "f_B")
+            128, model_for_generation=model, procedure=c_quizzes_procedure
         )
 
         quiz_machine.problem.save_quizzes_as_image(
@@ -541,10 +550,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
-            model_for_generation=model_for_generation,
-            temperature_hot=args.temperature_hot,
-            temperature_cold=args.temperature_cold,
-            to_recycle=to_recycle,
+            model_for_generation=model,
+            procedure=c_quizzes_procedure,
+            # to_recycle=to_recycle,
         )
 
         # We discard the trivial ones, according to a criterion
@@ -662,7 +670,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 ######################################################################
 
 
-def train_auto_encoder():
+def train_autoencoder():
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
@@ -672,10 +680,9 @@ def train_auto_encoder():
         nb_blocks=args.nb_blocks,
         causal=False,
         dropout=args.dropout,
+        autoencoder_dim=args.autoencoder_dim,
     ).to(main_device)
 
-    model.make_auto_encoder(auto_encoder_dim=64)
-
     test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -713,7 +720,7 @@ def train_auto_encoder():
 
         log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")
 
-        filename = f"auto_encoder.pth"
+        filename = f"autoencoder.pth"
         torch.save(
             model.state_dict(),
             os.path.join(args.result_dir, filename),
@@ -722,7 +729,7 @@ def train_auto_encoder():
 
         with torch.autograd.no_grad():
             model.eval()
-            input = test_w_quizzes[:128, -l:]
+            input = test_w_quizzes[0 * 128 : 1 * 128, -l:]
             z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device)))
             logits = model.decode(z_shape).x
 
@@ -739,12 +746,32 @@ def train_auto_encoder():
                 q,
             )
 
-    return model
+            input1 = test_w_quizzes[1 * 128 : 2 * 128, -l:]
+            input2 = test_w_quizzes[2 * 128 : 3 * 128, -l:]
+            z_shape1 = model.encode(mygpt.BracketedSequence(input1.to(main_device)))
+            z_shape2 = model.encode(mygpt.BracketedSequence(input2.to(main_device)))
+            z_shape = ((z_shape1[0] + z_shape2[0]) * 0.5, z_shape1[1])
+            logits = model.decode(z_shape).x
+
+            q = logits.argmax(dim=-1)
+            # q = q.reshape(q.size(0) // 2, 2, -1)
+            # input = input.reshape(input.size(0) // 2, 2, -1)
+            # q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1)
+
+            q = q.reshape(q.size(0) // 4, -1)
+
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir,
+                f"culture_mix_ae_{n_epoch:04d}.png",
+                q,
+            )
 
+    return model
 
-# ae = train_auto_encoder()
 
-# exit(0)
+if args.autoencoder_dim > 0:
+    ae = train_autoencoder()
+    exit(0)
 
 ######################################################################
 
index b38cc99..fca2067 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -255,6 +255,7 @@ class MyGPT(nn.Module):
         nb_heads,
         nb_blocks,
         causal=False,
+        autoencoder_dim=-1,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -303,6 +304,26 @@ class MyGPT(nn.Module):
             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
         )
 
+        # -------------------------------------------------------
+        if autoencoder_dim > 0:
+            self.encoder = nn.Sequential(
+                *(
+                    trunk_blocks[: nb_blocks // 2]
+                    + [EncoderHead(dim_model, autoencoder_dim)]
+                )
+            )
+
+            self.decoder = nn.Sequential(
+                *(
+                    [
+                        DecoderBottom(autoencoder_dim, dim_model),
+                        AddPositionalEncoding(len_max),
+                    ]
+                    + trunk_blocks[nb_blocks // 2 :]
+                )
+            )
+        # -------------------------------------------------------
+
         with torch.no_grad():
             for m in self.modules():
                 if isinstance(m, nn.Embedding):
@@ -318,24 +339,6 @@ class MyGPT(nn.Module):
         bs = self.readout(bs)
         return bs
 
-    def make_auto_encoder(self, auto_encoder_dim):
-        self.encoder = nn.Sequential(
-            *(
-                trunk_blocks[: nb_blocks // 2]
-                + [EncoderHead(dim_model, auto_encoder_dim)]
-            )
-        )
-
-        self.decoder = nn.Sequential(
-            *(
-                [
-                    DecoderBottom(auto_encoder_dim, dim_model),
-                    AddPositionalEncoding(len_max),
-                ]
-                + trunk_blocks[nb_blocks // 2 :]
-            )
-        )
-
     def encode(self, bs):
         bs = self.embedding(bs)
         z = self.encoder(bs)
index 083b50e..d4b463b 100755 (executable)
@@ -433,7 +433,34 @@ class QuizMachine:
 
     ###############################################################
 
-    def generate_c_quizzes(
+    def generate_c_quizzes(self, nb, model_for_generation, procedure):
+        seq_logproba = torch.zeros(nb, device=self.device)
+
+        c_quizzes = None
+
+        for s, m, t in procedure:
+            if c_quizzes is None:
+                c_quizzes = self.problem.create_empty_quizzes(nb, s)
+                c_quizzes = c_quizzes.to(self.device)
+            elif s != pred_s:
+                c_quizzes = self.problem.reconfigure(c_quizzes, s)
+            pred_s = s
+
+            self.autoregression(
+                model=model_for_generation,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes, s, m),
+                seq_logproba=seq_logproba,
+                logit_transformer=t,
+            )
+
+            c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
+
+        return c_quizzes.to("cpu")
+
+    ######################################################################
+
+    def generate_c_quizzes_orig(
         self,
         nb,
         model_for_generation,