From 6f61e1a634805275b2378fc04d4d2d7594201f60 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 24 Jul 2024 14:04:08 +0200 Subject: [PATCH] Update. --- grids.py | 74 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/grids.py b/grids.py index 131f85c..eaba99a 100755 --- a/grids.py +++ b/grids.py @@ -122,12 +122,20 @@ class Grids(problem.Problem): S = self.height * self.width return ( - (quizzes[:, 0 * (S + 1)] == self.l2tok(struct[0])) - & (quizzes[:, 1 * (S + 1)] == self.l2tok(struct[1])) - & (quizzes[:, 2 * (S + 1)] == self.l2tok(struct[2])) - & (quizzes[:, 3 * (S + 1)] == self.l2tok(struct[3])) + (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]]) + & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]]) + & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]]) + & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]]) ).all() + def get_structure(self, quizzes): + S = self.height * self.width + struct = tuple( + self.tok2l[n.item()] for n in quizzes.reshape(-1, 4, S + 1)[0, :, 0] + ) + self.check_structure(quizzes, struct) + return struct + def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)): assert check_structure(quizzes, struct) @@ -141,24 +149,20 @@ class Grids(problem.Problem): return ar_mask - def reconfigure( - self, - quizzes, - struct_from=("A", "f_A", "B", "f_B"), - struct_to=("f_B", "A", "f_A", "B"), - ): - assert check_structure(quizzes, struct_from) + def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): + S = self.height * self.width + struct_from = self.get_structure(quizzes) sf = dict((l, n) for n, l in enumerate(struct_from)) result = quizzes.new(quizzes.size()) - q = quizzes.reshape(-1, 4, 4 * (S + 1)) - r = reshape.reshape(-1, 4, 4 * (S + 1)) + q = quizzes.reshape(-1, 4, S + 1) + r = result.reshape(-1, 4, S + 1) - r[:, 0, :] = q[:, sf[struct_to[0]]] - r[:, 1, :] = q[:, sf[struct_to[1]]] - r[:, 2, :] = q[:, sf[struct_to[2]]] - r[:, 3, :] = q[:, sf[struct_to[3]]] + r[:, 0] = q[:, sf[struct[0]], :] + r[:, 1] = q[:, sf[struct[1]], :] + r[:, 2] = q[:, sf[struct[2]], :] + r[:, 3] = q[:, sf[struct[3]], :] return result @@ -175,6 +179,7 @@ class Grids(problem.Problem): self.token_f_A = self.token_A + 1 self.token_B = self.token_f_A + 1 self.token_f_B = self.token_B + 1 + self.l2tok = { "A": self.token_A, "f_A": self.token_f_A, @@ -182,6 +187,13 @@ class Grids(problem.Problem): "f_B": self.token_f_B, } + self.tok2l = { + self.token_A: "A", + self.token_f_A: "f_A", + self.token_B: "B", + self.token_f_B: "f_B", + } + self.nb_token_values = self.token_f_B + 1 self.height = 10 @@ -302,10 +314,10 @@ class Grids(problem.Problem): + (1 - predicted_parts[:, :, None]) * white[None, None, :] ) - img_A = self.add_frame(img_A, colors[:, 0], thickness=6) - img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=6) - img_B = self.add_frame(img_B, colors[:, 2], thickness=6) - img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=6) + img_A = self.add_frame(img_A, colors[:, 0], thickness=8) + img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8) + img_B = self.add_frame(img_B, colors[:, 2], thickness=8) + img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8) img_A = self.add_frame(img_A, white[None, :], thickness=2) img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2) @@ -1377,6 +1389,12 @@ class Grids(problem.Problem): total=prompts.size(0), ) + quizzes[...] = 0 + quizzes[:, 0 * (S + 1)] = self.token_A + quizzes[:, 1 * (S + 1)] = self.token_f_A + quizzes[:, 2 * (S + 1)] = self.token_B + quizzes[:, 3 * (S + 1)] = self.token_f_B + for quiz in quizzes: q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width) q[...] = 0 @@ -1404,6 +1422,13 @@ if __name__ == "__main__": # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4) grids = Grids() + nb = 5 + quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) + print(grids.get_structure(quizzes)) + blah = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) + print(grids.get_structure(blah)) + exit(0) + # nb = 1000 # grids = problem.MultiThreadProblem( # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1 @@ -1424,9 +1449,14 @@ if __name__ == "__main__": # for t in [grids.task_symbols]: print(t.__name__) quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) + print(grids.get_structure(quizzes)) predicted_parts = quizzes.new_zeros(quizzes.size(0), 4) - predicted_parts[:, 3] = 1 + predicted_parts[:, 3] = torch.randint( + 2, (quizzes.size(0),), device=quizzes.device + ) + predicted_parts[:, :3] = 1 - predicted_parts[:, 3:] correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device) + correct_parts[:, 1:2] = correct_parts[:, :1] grids.save_quizzes_as_image( "/tmp", t.__name__ + ".png", -- 2.39.5