margin = 4
img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
+ h = img_prompts.size(2)
img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
img_answers, c=predicted_answers, margin=margin, bottom=True
)
- marker_size = 8
+ marker_size = 16
separator = img_prompts.new_full(
(
255,
)
- for k in range(2, 2 * marker_size - 3):
- i = k + 1 - marker_size
- j = marker_size - 2 - abs(k - marker_size + 1)
- separator[:, :, separator.size(2) // 2 + i, j] = 0
- separator[:, :, separator.size(2) // 2 + i + 1, j] = 0
+ separator[:, :, 0] = 0
+ separator[:, :, h - 1] = 0
+
+ for k in range(1, 2 * marker_size - 8):
+ i = k - (marker_size - 4)
+ j = marker_size - 5 - abs(i)
+ separator[:, :, h // 2 - 1 + i, 2 + j] = 0
+ separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
img = torch.cat([img_prompts, separator, img_answers], dim=3)
image_name = os.path.join(result_dir, filename)
torchvision.utils.save_image(
- img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
+ img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
)
######################################################################
prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+
# warnings.warn("dirty test with longer answer", RuntimeWarning)
# answers = torch.cat(
# [