- for k in range(2, upscale - 2):
- direction_symbol[
- n, :, (height * upscale) // 2 - upscale // 2 + k, k
- ] = 0
- direction_symbol[
- n,
- :,
- (height * upscale) // 2 - upscale // 2 + k,
- upscale - 1 - k,
- ] = 0
-
- return torch.cat(
- [
- frame2img(f_first, height, width, upscale),
- separator,
- direction_symbol,
- separator,
- frame2img(f_second, height, width, upscale),
- ],
- dim=3,
+ c = c.long()[:, None]
+ c = c * torch.tensor([192, 192, 192], device=c.device) + (
+ 1 - c
+ ) * torch.tensor([255, 255, 255], device=c.device)
+ y[...] = c[:, :, None, None]
+ y[:, :, margin:-margin, margin:-margin] = x
+ return y
+
+ margin = 4
+
+ img_prompts = add_frame(self.frame2img(prompts.to("cpu")), 0, 1)
+ img_answers = add_frame(self.frame2img(answers.to("cpu")), 0, 1)
+
+ # img_prompts = add_frame(img_prompts, 255, margin)
+ # img_answers = add_frame(img_answers, 255, margin)
+
+ img_prompts = add_frame(img_prompts, predicted_prompts, margin)
+ img_answers = add_frame(img_answers, predicted_answers, margin)
+
+ separator = img_prompts.new_full(
+ (img_prompts.size(0), img_prompts.size(1), img_prompts.size(2), margin), 255