+ def save_image(
+ self,
+ result_dir,
+ filename,
+ prompts,
+ answers,
+ predicted_prompts=None,
+ predicted_answers=None,
+ ):
+ if predicted_prompts is None:
+ predicted_prompts = 255
+
+ if predicted_answers is None:
+ predicted_answers = 255
+
+ def add_frame(x, c, margin, bottom=False):
+ if bottom:
+ h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
+ else:
+ h, w, di, dj = (
+ x.size(2) + 2 * margin,
+ x.size(3) + 2 * margin,
+ margin,
+ margin,
+ )
+
+ y = x.new_full((x.size(0), x.size(1), h, w), 0)
+
+ if type(c) is int:
+ y[...] = c
+ else:
+ c = c.long()[:, None]
+ c = (
+ (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+ + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+ + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+ )
+ y[...] = c[:, :, None, None]
+
+ y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+
+ return y
+
+ margin = 4
+
+ img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
+ h = img_prompts.size(2)
+ img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
+
+ img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
+ img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
+
+ img_prompts = add_frame(
+ img_prompts, c=predicted_prompts, margin=margin, bottom=True
+ )
+ img_answers = add_frame(
+ img_answers, c=predicted_answers, margin=margin, bottom=True
+ )
+
+ marker_size = 16
+
+ separator = img_prompts.new_full(
+ (
+ img_prompts.size(0),
+ img_prompts.size(1),
+ img_prompts.size(2),
+ marker_size,
+ ),
+ 255,
+ )
+
+ separator[:, :, 0] = 0
+ separator[:, :, h - 1] = 0
+
+ for k in range(1, 2 * marker_size - 8):
+ i = k - (marker_size - 4)
+ j = marker_size - 5 - abs(i)
+ separator[:, :, h // 2 - 1 + i, 2 + j] = 0
+ separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+ img = torch.cat([img_prompts, separator, img_answers], dim=3)
+