predicted_prompts=None,
predicted_answers=None,
):
+ prompts = prompts.reshape(prompts.size(0), self.height, -1)
+ answers = answers.reshape(answers.size(0), self.height, -1)
+
if predicted_prompts is None:
predicted_prompts = 255
y[...] = c
else:
c = c.long()[:, None]
- c = c * torch.tensor([0, 0, 0], device=c.device) + (
+ c = c * torch.tensor([192, 192, 192], device=c.device) + (
1 - c
) * torch.tensor([255, 255, 255], device=c.device)
y[...] = c[:, :, None, None]
return y
- margin = 4
-
- img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
- h = img_prompts.size(2)
- img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
- img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
- img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
+ margin = 8
- img_prompts = add_frame(
- img_prompts, c=predicted_prompts, margin=margin, bottom=True
+ img_prompts = torch.cat(
+ [
+ add_frame(
+ add_frame(self.frame2img(x), c=0, margin=1),
+ c=predicted_prompts,
+ margin=margin,
+ )
+ for x in prompts.to("cpu").split(split_size=self.width, dim=2)
+ ],
+ dim=3,
)
+
+ h = img_prompts.size(2)
img_answers = add_frame(
- img_answers, c=predicted_answers, margin=margin, bottom=True
+ add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
+ c=predicted_answers,
+ margin=margin,
)
- marker_size = 16
+ separator_size = 2 * margin
separator = img_prompts.new_full(
(
img_prompts.size(0),
img_prompts.size(1),
img_prompts.size(2),
- marker_size,
+ separator_size,
),
255,
)
- separator[:, :, 0] = 0
- separator[:, :, h - 1] = 0
-
- for k in range(1, 2 * marker_size - 8):
- i = k - (marker_size - 4)
- j = marker_size - 5 - abs(i)
- separator[:, :, h // 2 - 1 + i, 2 + j] = 0
- separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+ marker = img_prompts.new_full(
+ (
+ img_prompts.size(0),
+ img_prompts.size(1),
+ img_prompts.size(2),
+ separator_size,
+ ),
+ 255,
+ )
- img = torch.cat([img_prompts, separator, img_answers], dim=3)
+ # marker[:, :, 0] = 0
+ # marker[:, :, h - 1] = 0
+
+ for k in range(1, 2 * separator_size - 8):
+ i = k - (separator_size - 4)
+ j = separator_size - 5 - abs(i)
+ marker[:, :, h // 2 - 1 + i, 2 + j] = 0
+ marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+ img = torch.cat(
+ [
+ img_prompts,
+ marker,
+ img_answers,
+ ],
+ dim=3,
+ )
image_name = os.path.join(result_dir, filename)
torchvision.utils.save_image(
break
return i1, j1, i2, j2
def task_replace_color(self, A, f_A, B, f_B):
    """Color-replacement task: fill one random rectangle of A with color
    c1 and the same region of f_A with c2, then do likewise three times
    in the B / f_B pair.

    Reconstructed from the '+' side of the embedded diff — the raw
    '+'/'-' patch markers made this span non-executable Python.

    Args:
        A, f_A, B, f_B: 2D integer grids, mutated in place.
    """
    # Pick two distinct non-background colors; +1 skips index 0,
    # presumably the background color — confirm against self.colors.
    c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
    i1, j1, i2, j2 = self.rec_coo(A)
    A[i1:i2, j1:j2] = c1
    f_A[i1:i2, j1:j2] = c2
    # Three independent rectangles in the B pair; they may overlap, in
    # which case later rectangles simply overwrite earlier ones.
    for _ in range(3):
        i1, j1, i2, j2 = self.rec_coo(B)
        B[i1:i2, j1:j2] = c1
        f_B[i1:i2, j1:j2] = c2
+
def move_color(self, A, f_A, B, f_B):
    """Draw two random rectangles (colors c1 and c2) identically into A
    and f_A.

    Reconstructed from the '+' side of the embedded diff — the raw
    patch markers made this span non-executable Python.

    NOTE(review): despite the name, nothing is "moved" here — A and f_A
    end up with identical content, and B / f_B are accepted but never
    written. This looks like work in progress; confirm the intended
    semantics before relying on it.

    Args:
        A, f_A: 2D integer grids, mutated in place.
        B, f_B: accepted for signature parity with the task_* methods;
            currently unused.
    """
    c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1

    i1, j1, i2, j2 = self.rec_coo(A)
    A[i1:i2, j1:j2] = c1
    f_A[i1:i2, j1:j2] = c1

    # Resample until the second rectangle stops short of the bottom row
    # (i2 < height - 1) — presumably to leave room for a later "move";
    # relies on rec_coo eventually returning such coordinates.
    while True:
        i1, j1, i2, j2 = self.rec_coo(A)
        if i2 < self.height - 1:
            break
    A[i1:i2, j1:j2] = c2
    f_A[i1:i2, j1:j2] = c2
def generate_prompts_and_answers(self, nb):
    """Generate `nb` quizzes by running the active task on fresh grids.

    Reconstructed from the '+' side of the embedded diff — the raw
    '+'/'-' patch markers made this span non-executable Python.

    Args:
        nb: number of quizzes to generate.

    Returns:
        prompts: int64 tensor of shape (nb, height * width * 3) — the
            three side-by-side grids A | f_A | B, flattened per sample.
        answers: int64 tensor of shape (nb, height * width) — the
            target grid f_B, flattened per sample.
    """
    prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
    answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
    w = self.width
    for prompt, answer in zip(prompts, answers):
        # These slices are views into the shared storage: the task
        # method fills them in place, so prompts/answers are updated
        # without any explicit copy-back.
        A = prompt[:, 0 * w : 1 * w]
        f_A = prompt[:, 1 * w : 2 * w]
        B = prompt[:, 2 * w : 3 * w]
        f_B = answer
        # self.task_replace_color(A, f_A, B, f_B)
        self.move_color(A, f_A, B, f_B)
    return prompts.flatten(1), answers.flatten(1)
def save_quizzes(
self,
# Smoke-test driver: generate a small batch of quizzes and render them,
# randomly flagging each sample's prompt or (exclusively) its answer as
# "predicted" to exercise both framing paths of save_quizzes.
# Reconstructed from the '+' side of the embedded diff — the raw
# '+'/'-' patch markers made this span non-executable Python.
prompts, answers = lang.generate_prompts_and_answers(24)
predicted_prompts = torch.rand(prompts.size(0)) < 0.5
# Each sample's answer is marked predicted exactly when its prompt is not.
predicted_answers = torch.logical_not(predicted_prompts)
lang.save_quizzes(
    "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
)
# start_time = time.perf_counter()