Update.
author	François Fleuret <francois@fleuret.org>
Sat, 3 Aug 2024 11:15:44 +0000 (13:15 +0200)
committer	François Fleuret <francois@fleuret.org>
Sat, 3 Aug 2024 11:15:44 +0000 (13:15 +0200)
grids.py
main.py

diff --git a/grids.py b/grids.py
index 52db2b5..05c3057 100755
--- a/grids.py
+++ b/grids.py
@@ -574,6 +574,53 @@ class Grids(problem.Problem):
 
     ######################################################################
 
+    def sample_rworld_states(self, N=1000):
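+        # Rejection sampling of random rectangle worlds: draw N candidate
+        # sets of rectangles and keep only the collision-free ones.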
+        nb_rec_max = 5
+        while True:
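+            # rn: number of active rectangles per sample, in [2, nb_rec_max]
+            # ri, rj: sorted corner coordinates (rows and columns)
+            # rz: z-plane of each rectangle, rc: its color (nonzero)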
+            rn = torch.randint(nb_rec_max - 1, (N,)) + 2
+            ri = torch.randint(self.height, (N, nb_rec_max, 2)).sort(dim=2).values
+            rj = torch.randint(self.width, (N, nb_rec_max, 2)).sort(dim=2).values
+            rz = torch.randint(2, (N, nb_rec_max))
+            rc = torch.randint(self.nb_colors - 1, (N, nb_rec_max)) + 1
+            n = torch.arange(nb_rec_max)
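+            # Count colliding pairs k > l per sample: both row and column
+            # intervals intersect, same z-plane, and both rectangles are
+            # among the rn active ones.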
+            nb_collisions = (
+                (
+                    (ri[:, :, None, 0] <= ri[:, None, :, 1])
+                    & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+                    & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+                    & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+                    & (rz[:, :, None] == rz[:, None, :])
+                    & (n[None, :, None] < rn[:, None, None])
+                    & (n[None, None, :] < n[None, :, None])
+                )
+                .long()
+                .flatten(1)
+                .sum(dim=1)
+            )
+
+            no_collision = nb_collisions == 0
+
+            if no_collision.any():
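+                # Fraction of candidate samples that were collision-free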
+                print(no_collision.long().sum() / N)
+                self.rn = rn[no_collision]
+                self.ri = ri[no_collision]
+                self.rj = rj[no_collision]
+                self.rz = rz[no_collision]
+                self.rc = rc[no_collision]
+                break
+
+    def task_rworld_change_color(self, A, f_A, B, f_B):
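+        # Draw nb_rec non-overlapping rectangles; in the answer the first
+        # rectangle takes the spare color c[-1], the others are unchanged.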
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
+
+    ######################################################################
+
     # @torch.compile
     def task_replace_color(self, A, f_A, B, f_B):
         nb_rec = 3
@@ -1656,6 +1703,8 @@ if __name__ == "__main__":
     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
 
     grids = Grids()
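+    # Quick sanity check of the rworld sampler: prints the acceptance
+    # rate of the rejection sampling and exits.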
+    grids.sample_rworld_states()
+    exit(0)
 
     # nb = 5
     # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
diff --git a/main.py b/main.py
index 072551e..36b58e2 100755
--- a/main.py
+++ b/main.py
@@ -108,7 +108,7 @@ parser.add_argument("--nb_averaging_rounds", type=int, default=3)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
-parser.add_argument("--test_generator", action="store_true", default=False)
+parser.add_argument("--test", type=str, default=None)
 
 ######################################################################
 
@@ -1065,7 +1065,59 @@ if args.dirty_debug:
 
 ######################################################################
 
-if args.test_generator:
+if args.test == "tsne":
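+    # Visualize mid-depth activations of generated quizzes with t-SNE,
+    # one label per task.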
+    model = models[0]
+
+    quizzes = []
+    labels = []
+    nb_samples_per_task = 1000
+
+    for n, t in enumerate(args.grids_world_tasks.split(",")):
+        quizzes.append(
+            quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t])
+        )
+        labels.append(torch.full((quizzes[-1].size(0),), n))
+
+    quizzes = torch.cat(quizzes, dim=0)
+    labels = torch.cat(labels, dim=0)
+
+    with torch.autograd.no_grad():
+        model.eval().to(main_device)
+        record = []
+        for input, targets in zip(
+            quizzes.split(args.batch_size), labels.split(args.batch_size)
+        ):
+            input = input.to(main_device)
+            bs = mygpt.BracketedSequence(input)
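+            # Shift the input one token to the right, as the model sees it
+            # for next-token prediction.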
+            bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+            bs = model.embedding(bs)
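+            # Use the activations of the single block at mid-depth, applied
+            # to the embeddings, as the representation to visualize.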
+            bs = model.trunk[args.nb_blocks // 2](bs)
+            record.append((bs.x.to("cpu"), targets))
+
+    x = torch.cat([x for x, y in record], dim=0).flatten(1)
+    y = torch.cat([y for x, y in record], dim=0)
+
+    print(f"{x.size()=} {y.size()=}")
+    # torch.save((x,y), "/tmp/embed.pth")
+    # exit(0)
+
+    from sklearn.manifold import TSNE
+
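+    # Project the flattened activations to 2D with t-SNE.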
+    x_np = x.numpy()
+    z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np)
+    z = torch.from_numpy(z_np)
+
+    print(f"{z.size()=}")
+
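+    # One line per sample: task label, then the two t-SNE coordinates
+    # (e.g. for gnuplot).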
+    with open("/tmp/result.dat", "w") as f:
+        for k in range(z.size(0)):
+            f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n")
+
+    exit(0)
+
+######################################################################
+
+if args.test == "generator":
     token_prolog_0 = vocabulary_size + 0
     token_prolog_1 = vocabulary_size + 1
     token_prolog_2 = vocabulary_size + 2