From 00f321c2e2a9b7be1edcb1453bf0d45f52e50919 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 13:15:08 +0300 Subject: [PATCH] Update. --- sky.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/sky.py b/sky.py index e509fb7..6ef8a3a 100755 --- a/sky.py +++ b/sky.py @@ -229,6 +229,7 @@ class Sky(problem.Problem): margin = 4 img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1) + h = img_prompts.size(2) img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1) img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True) @@ -241,7 +242,7 @@ class Sky(problem.Problem): img_answers, c=predicted_answers, margin=margin, bottom=True ) - marker_size = 8 + marker_size = 16 separator = img_prompts.new_full( ( @@ -253,17 +254,20 @@ class Sky(problem.Problem): 255, ) - for k in range(2, 2 * marker_size - 3): - i = k + 1 - marker_size - j = marker_size - 2 - abs(k - marker_size + 1) - separator[:, :, separator.size(2) // 2 + i, j] = 0 - separator[:, :, separator.size(2) // 2 + i + 1, j] = 0 + separator[:, :, 0] = 0 + separator[:, :, h - 1] = 0 + + for k in range(1, 2 * marker_size - 8): + i = k - (marker_size - 4) + j = marker_size - 5 - abs(i) + separator[:, :, h // 2 - 1 + i, 2 + j] = 0 + separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 img = torch.cat([img_prompts, separator, img_answers], dim=3) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0 + img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0 ) ###################################################################### @@ -278,6 +282,7 @@ class Sky(problem.Problem): prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) + # warnings.warn("dirty test with longer answer", RuntimeWarning) # answers = torch.cat( # [ -- 2.39.5