+ def save_image(
+ self,
+ result_dir,
+ filename,
+ prompts,
+ answers,
+ predicted_prompts=None,
+ predicted_answers=None,
+ ):
+ if predicted_prompts is None:
+ predicted_prompts = 255
+
+ if predicted_answers is None:
+ predicted_answers = 255
+
+ def add_frame(x, c, margin):
+ y = x.new_full(
+ (x.size(0), x.size(1), x.size(2) + 2 * margin, x.size(3) + 2 * margin),
+ 0,
+ )
+ if type(c) is int:
+ y[...] = c
+ else:
+ c = c.long()[:, None]
+ c = c * torch.tensor([192, 192, 192], device=c.device) + (
+ 1 - c
+ ) * torch.tensor([255, 255, 255], device=c.device)
+ y[...] = c[:, :, None, None]
+ y[:, :, margin:-margin, margin:-margin] = x
+ return y
+
+ margin = 4
+
+ img_prompts = add_frame(self.frame2img(prompts.to("cpu")), 0, 1)
+ img_answers = add_frame(self.frame2img(answers.to("cpu")), 0, 1)
+
+ # img_prompts = add_frame(img_prompts, 255, margin)
+ # img_answers = add_frame(img_answers, 255, margin)
+
+ img_prompts = add_frame(img_prompts, predicted_prompts, margin)
+ img_answers = add_frame(img_answers, predicted_answers, margin)
+
+ separator = img_prompts.new_full(
+ (img_prompts.size(0), img_prompts.size(1), img_prompts.size(2), margin), 255
+ )
+
+ img = torch.cat([img_prompts, img_answers], dim=3)
+