From 0fed14cc48d4b6b5337938a33ec06c03e1da335a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 24 Jul 2024 21:04:21 +0200 Subject: [PATCH] Update. --- grids.py | 82 +++++++++++++++++++++++++------------------------ quiz_machine.py | 9 +++--- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/grids.py b/grids.py index 80b8b1d..ee3a1e6 100755 --- a/grids.py +++ b/grids.py @@ -264,8 +264,8 @@ class Grids(problem.Problem): y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - y[:, :, :, torch.arange(0, y.size(3), scale)] = 224 - y[:, :, torch.arange(0, y.size(2), scale), :] = 224 + y[:, :, :, torch.arange(0, y.size(3), scale)] = 64 + y[:, :, torch.arange(0, y.size(2), scale), :] = 64 for n in range(m.size(0)): for i in range(m.size(1)): @@ -312,15 +312,15 @@ class Grids(problem.Problem): .permute(1, 0, 2, 3) ) - black, white, gray, green, red = torch.tensor( - [[0, 0, 0], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]], + frame, white, gray, green, red = torch.tensor( + [[64, 64, 64], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]], device=quizzes.device, ) - img_A = self.add_frame(self.grid2img(A), black[None, :], thickness=1) - img_f_A = self.add_frame(self.grid2img(f_A), black[None, :], thickness=1) - img_B = self.add_frame(self.grid2img(B), black[None, :], thickness=1) - img_f_B = self.add_frame(self.grid2img(f_B), black[None, :], thickness=1) + img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1) + img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1) + img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1) + img_f_B = self.add_frame(self.grid2img(f_B), frame[None, :], thickness=1) # predicted_parts Nx4 # correct_parts Nx4 @@ -328,12 +328,14 @@ class Grids(problem.Problem): if predicted_parts is None: colors = white[None, None, :].expand(-1, 4, -1) else: + predicted_parts = predicted_parts.to("cpu") if correct_parts is None: colors = ( predicted_parts[:, :, None] * gray[None, None, :] + (1 - predicted_parts[:, :, None]) * white[None, None, :] ) else: + correct_parts = correct_parts.to("cpu") colors = ( predicted_parts[:, :, None] * ( @@ -1448,30 +1450,31 @@ if __name__ == "__main__": import time # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4) + grids = Grids() - nb = 5 - quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) - print(quizzes) - print(grids.get_structure(quizzes)) - quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) - print("DEBUG2", quizzes) - print(grids.get_structure(quizzes)) - print(quizzes) + # nb = 5 + # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) + # print(quizzes) + # print(grids.get_structure(quizzes)) + # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) + # print("DEBUG2", quizzes) + # print(grids.get_structure(quizzes)) + # print(quizzes) - i = torch.rand(quizzes.size(0)) < 0.5 + # i = torch.rand(quizzes.size(0)) < 0.5 - quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A")) + # quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A")) - j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A")) + # j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A")) - print( - i.equal(j), - grids.get_structure(quizzes[j]), - grids.get_structure(quizzes[j == False]), - ) + # print( + # i.equal(j), + # grids.get_structure(quizzes[j]), + # grids.get_structure(quizzes[j == False]), + # ) - exit(0) + # exit(0) # nb = 1000 # grids = problem.MultiThreadProblem( @@ -1488,37 +1491,36 @@ if __name__ == "__main__": nb, nrow = 128, 4 # nb, nrow = 8, 2 - # for t in grids.all_tasks: - for t in [grids.task_replace_color]: + for t in grids.all_tasks: + # for t in [grids.task_replace_color]: # for t in [grids.task_symbols]: print(t.__name__) quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) - print(grids.get_structure(quizzes)) - predicted_parts = quizzes.new_zeros(quizzes.size(0), 4) - predicted_parts[:, 3] = torch.randint( - 2, (quizzes.size(0),), device=quizzes.device - ) - predicted_parts[:, :3] = 1 - predicted_parts[:, 3:] - correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device) - correct_parts[:, 1:2] = correct_parts[:, :1] + # predicted_parts = quizzes.new_zeros(quizzes.size(0), 4) + # predicted_parts[:, 3] = torch.randint( + # 2, (quizzes.size(0),), device=quizzes.device + # ) + # predicted_parts[:, :3] = 1 - predicted_parts[:, 3:] + # correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device) + # correct_parts[:, 1:2] = correct_parts[:, :1] grids.save_quizzes_as_image( "/tmp", t.__name__ + ".png", quizzes, - predicted_parts=predicted_parts, - correct_parts=correct_parts, + # predicted_parts=predicted_parts, + # correct_parts=correct_parts, ) - exit(0) + # exit(0) nb = 1000 for t in grids.all_tasks: # for t in [grids.task_compute]: start_time = time.perf_counter() - prompts, answers = grids.generate_w_quizzes_(nb, tasks=[t]) + w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) delay = time.perf_counter() - start_time - print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s") + print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s") exit(0) diff --git a/quiz_machine.py b/quiz_machine.py index a384377..749ae8b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -214,7 +214,7 @@ class QuizMachine: input = input.to(self.device) result = input.new(input.size()) correct = torch.empty(input.size(0), device=input.device, dtype=torch.bool) - + predicted_parts = input.new(input.size(0), 4) nb = 0 for struct, mask in [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), @@ -226,8 +226,8 @@ class QuizMachine: result[i], correct[i] = self.predict( model=model, quizzes=input[i], struct=struct, mask=mask ) + predicted_parts[i] = torch.tensor(mask, device=self.device)[None, :] - print(f"{nb=} {input.size(0)=}") assert nb == input.size(0) main_test_accuracy = correct.sum() / correct.size(0) @@ -238,6 +238,7 @@ class QuizMachine: result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", quizzes=result[:128], + predicted_parts=predicted_parts, ) return main_test_accuracy @@ -266,8 +267,6 @@ class QuizMachine: model.test_w_quizzes, configurations=self.configurations ) - # print(model.train_w_quizzes.sum()) - ###################################################################### def renew_train_w_quizzes(self, model): @@ -368,7 +367,7 @@ class QuizMachine: seq_logproba[...] = 0.0 c_quizzes = c_quizzes.to(self.device) - print(self.problem.get_structure(c_quizzes)) + reversed_c_quizzes = self.problem.reconfigure( c_quizzes, ("f_A", "A", "f_B", "B") ) -- 2.39.5