predicted_answers = 255
def add_frame(x, c, margin, bottom=False):
+ print(f"{type(x)=} {type(c)=}")
if bottom:
h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
else:
y[...] = c
else:
c = c.long()[:, None]
- c = c * torch.tensor([0, 0, 0], device=c.device) + (
+ c = c * torch.tensor([192, 192, 192], device=c.device) + (
1 - c
) * torch.tensor([255, 255, 255], device=c.device)
y[...] = c[:, :, None, None]
return y
- margin = 4
+ margin = 8
- img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
- h = img_prompts.size(2)
- img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
- img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
- img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
-
- img_prompts = add_frame(
- img_prompts, c=predicted_prompts, margin=margin, bottom=True
+ img_prompts = torch.cat(
+ [
+ add_frame(
+ add_frame(self.frame2img(x), c=0, margin=1),
+ c=predicted_prompts,
+ margin=margin,
+ )
+ for x in prompts.to("cpu").split(split_size=self.width, dim=2)
+ ],
+ dim=3,
)
+
+ h = img_prompts.size(2)
img_answers = add_frame(
- img_answers, c=predicted_answers, margin=margin, bottom=True
+ add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
+ c=predicted_answers,
+ margin=margin,
)
- marker_size = 16
+ separator_size = 2 * margin
separator = img_prompts.new_full(
(
img_prompts.size(0),
img_prompts.size(1),
img_prompts.size(2),
- marker_size,
+ separator_size,
),
255,
)
- separator[:, :, 0] = 0
- separator[:, :, h - 1] = 0
-
- for k in range(1, 2 * marker_size - 8):
- i = k - (marker_size - 4)
- j = marker_size - 5 - abs(i)
- separator[:, :, h // 2 - 1 + i, 2 + j] = 0
- separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+ marker = img_prompts.new_full(
+ (
+ img_prompts.size(0),
+ img_prompts.size(1),
+ img_prompts.size(2),
+ separator_size,
+ ),
+ 255,
+ )
- img = torch.cat([img_prompts, separator, img_answers], dim=3)
+ # marker[:, :, 0] = 0
+ # marker[:, :, h - 1] = 0
+
+ for k in range(1, 2 * separator_size - 8):
+ i = k - (separator_size - 4)
+ j = separator_size - 5 - abs(i)
+ marker[:, :, h // 2 - 1 + i, 2 + j] = 0
+ marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+ img = torch.cat(
+ [
+ img_prompts,
+ marker,
+ img_answers,
+ ],
+ dim=3,
+ )
image_name = os.path.join(result_dir, filename)
torchvision.utils.save_image(
prompts, answers = lang.generate_prompts_and_answers(24)
- # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
- # predicted_answers = torch.rand(answers.size(0)) < 0.5
+ predicted_prompts = torch.rand(prompts.size(0)) < 0.5
+ predicted_answers = torch.logical_not(predicted_prompts)
lang.save_quizzes(
- "/tmp", "test", prompts, answers # , predicted_prompts, predicted_answers
+ "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
)
# start_time = time.perf_counter()