From ea8c7e5bc9cd29f4cc297a010cdf0eafc626c2a0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 3 Aug 2024 13:15:44 +0200 Subject: [PATCH] Update. --- grids.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/grids.py b/grids.py index 52db2b5..05c3057 100755 --- a/grids.py +++ b/grids.py @@ -574,6 +574,53 @@ class Grids(problem.Problem): ###################################################################### + def sample_rworld_states(self, N=1000): + nb_rec_max = 5 + while True: + rn = torch.randint(nb_rec_max - 1, (N,)) + 2 + ri = torch.randint(self.height, (N, nb_rec_max, 2)).sort(dim=2).values + rj = torch.randint(self.width, (N, nb_rec_max, 2)).sort(dim=2).values + rz = torch.randint(2, (N, nb_rec_max)) + rc = torch.randint(self.nb_colors - 1, (N, nb_rec_max)) + 1 + n = torch.arange(nb_rec_max) + nb_collisions = ( + ( + (ri[:, :, None, 0] <= ri[:, None, :, 1]) + & (ri[:, :, None, 1] >= ri[:, None, :, 0]) + & (rj[:, :, None, 0] <= rj[:, None, :, 1]) + & (rj[:, :, None, 1] >= rj[:, None, :, 0]) + & (rz[:, :, None] == rz[:, None, :]) + & (n[None, :, None] < rn[:, None, None]) + & (n[None, None, :] < n[None, :, None]) + ) + .long() + .flatten(1) + .sum(dim=1) + ) + + no_collision = nb_collisions == 0 + + if no_collision.any(): + print(no_collision.long().sum() / N) + self.rn = rn[no_collision] + self.ri = ri[no_collision] + self.rj = rj[no_collision] + self.rz = rz[no_collision] + self.rc = rc[no_collision] + break + + def task_rworld_change_color(self, A, f_A, B, f_B): + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(nb_rec, prevent_overlap=True) + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] + + ###################################################################### + # @torch.compile def task_replace_color(self, A, f_A, B, f_B): nb_rec = 3 @@ -1656,6 +1703,8 @@ if __name__ == "__main__": # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4) grids = Grids() + grids.sample_rworld_states() + exit(0) # nb = 5 # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) diff --git a/main.py b/main.py index 072551e..36b58e2 100755 --- a/main.py +++ b/main.py @@ -108,7 +108,7 @@ parser.add_argument("--nb_averaging_rounds", type=int, default=3) parser.add_argument("--dirty_debug", action="store_true", default=False) -parser.add_argument("--test_generator", action="store_true", default=False) +parser.add_argument("--test", type=str, default=None) ###################################################################### @@ -1065,7 +1065,59 @@ if args.dirty_debug: ###################################################################### -if args.test_generator: +if args.test == "tsne": + model = models[0] + + quizzes = [] + labels = [] + nb_samples_per_task = 1000 + + for n, t in enumerate(args.grids_world_tasks.split(",")): + quizzes.append( + quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t]) + ) + labels.append(torch.full((quizzes[-1].size(0),), n)) + + quizzes = torch.cat(quizzes, dim=0) + labels = torch.cat(labels, dim=0) + + with torch.autograd.no_grad(): + model.eval().to(main_device) + record = [] + for input, targets in zip( + quizzes.split(args.batch_size), labels.split(args.batch_size) + ): + input = input.to(main_device) + bs = mygpt.BracketedSequence(input) + bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) + bs = model.embedding(bs) + bs = model.trunk[args.nb_blocks // 2](bs) + record.append((bs.x.to("cpu"), targets)) + + x = torch.cat([x for x, y in record], dim=0).flatten(1) + y = torch.cat([y for x, y in record], dim=0) + + print(f"{x.size()=} {y.size()=}") + # torch.save((x,y), "/tmp/embed.pth") + # exit(0) + + from sklearn.manifold import TSNE + + x_np = x.numpy() + z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np) + z = torch.from_numpy(z_np) + + print(f"{z.size()=}") + + with open("/tmp/result.dat", "w") as f: + for k in range(z.size(0)): + f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n") + + exit(0) + +###################################################################### + +if args.test == "generator": token_prolog_0 = vocabulary_size + 0 token_prolog_1 = vocabulary_size + 1 token_prolog_2 = vocabulary_size + 2 -- 2.39.5