From 64126bc4886ed19754c3406a03bc18ef40367e7e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 1 Sep 2024 12:38:55 +0200 Subject: [PATCH] Update. --- grids.py | 12 +++++++++++- main.py | 9 ++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/grids.py b/grids.py index 98a0581..9441811 100755 --- a/grids.py +++ b/grids.py @@ -359,6 +359,7 @@ class Grids(problem.Problem): comment_height=48, nrow=4, margin=8, + delta=False, ): quizzes = quizzes.to("cpu") @@ -389,6 +390,10 @@ class Grids(problem.Problem): device=quizzes.device, ) + if delta: + u = (B != f_B).long() + img_delta = self.add_frame(self.grid2img(u), frame[None, :], thickness=1) + img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1) img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1) img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1) @@ -428,7 +433,12 @@ class Grids(problem.Problem): img_B = self.add_frame(img_B, white[None, :], thickness=2) img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2) - img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3) + if delta: + img_delta = self.add_frame(img_delta, colors[:, 0], thickness=8) + img_delta = self.add_frame(img_delta, white[None, :], thickness=2) + img = torch.cat([img_A, img_f_A, img_B, img_f_B, img_delta], dim=3) + else: + img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3) if comments is not None: comment_img = [text_img(comment_height, img.size(3), t) for t in comments] diff --git a/main.py b/main.py index 58d6287..2731f25 100755 --- a/main.py +++ b/main.py @@ -1357,8 +1357,8 @@ def generate_ae_c_quizzes(models, local_device=main_device): duration_max = 4 * 3600 - wanted_nb = 128 # 0000 - nb_to_save = 128 + wanted_nb = 16 # 0000 + nb_to_save = 16 with torch.autograd.no_grad(): records = [[] for _ in criteria] @@ -1394,9 +1394,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): duration = time.perf_counter() - start_time - log_string( - f"generate_c_quizz_generation_speed {int(3600 * wanted_nb / duration)}/h" - ) + log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h") for n, u in enumerate(records): quizzes = torch.cat(u, dim=0)[:nb_to_save] @@ -1418,6 +1416,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): # predicted_parts=predicted_parts, # correct_parts=correct_parts, comments=comments, + delta=True, ) log_string(f"wrote {filename}") -- 2.39.5