- img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
- h = img_prompts.size(2)
- img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
- img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
- img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
-
- img_prompts = add_frame(
- img_prompts, c=predicted_prompts, margin=margin, bottom=True
+ img_prompts = torch.cat(
+ [
+ add_frame(
+ add_frame(self.frame2img(x), c=0, margin=1),
+ c=predicted_prompts,
+ margin=margin,
+ )
+ for x in prompts.to("cpu").split(split_size=self.width, dim=2)
+ ],
+ dim=3,