predicted_prompts=None,
predicted_answers=None,
):
+ prompts = prompts.reshape(prompts.size(0), self.height, -1)
+ answers = answers.reshape(answers.size(0), self.height, -1)
+
if predicted_prompts is None:
predicted_prompts = 255
predicted_answers = 255
def add_frame(x, c, margin, bottom=False):
- print(f"{type(x)=} {type(c)=}")
if bottom:
h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
else:
break
return i1, j1, i2, j2
- def task_red_to_green(self, A, f_A, B, f_B):
+ def task_replace_color(self, A, f_A, B, f_B):
+ c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
+ i1, j1, i2, j2 = self.rec_coo(A)
+ A[i1:i2, j1:j2] = c1
+ f_A[i1:i2, j1:j2] = c2
+ for _ in range(3):
+ i1, j1, i2, j2 = self.rec_coo(B)
+ B[i1:i2, j1:j2] = c1
+ f_B[i1:i2, j1:j2] = c2
+
+ def move_color(self, A, f_A, B, f_B):
+ c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
+
i1, j1, i2, j2 = self.rec_coo(A)
- A[i1:i2, j1:j2] = self.name2color["red"]
- f_A[i1:i2, j1:j2] = self.name2color["green"]
- i1, j1, i2, j2 = self.rec_coo(B)
- B[i1:i2, j1:j2] = self.name2color["red"]
- f_B[i1:i2, j1:j2] = self.name2color["green"]
+ A[i1:i2, j1:j2] = c1
+ f_A[i1:i2, j1:j2] = c1
+
+ while True:
+ i1, j1, i2, j2 = self.rec_coo(A)
+ if i2 < self.height - 1:
+ break
+ A[i1:i2, j1:j2] = c2
+ f_A[i1:i2, j1:j2] = c2
def generate_prompts_and_answers(self, nb):
prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
w = self.width
for prompt, answer in zip(prompts, answers):
- self.task_red_to_green(
- prompt[:, 0 * w : 1 * w],
- prompt[:, 1 * w : 2 * w],
- prompt[:, 2 * w : 3 * w],
- answer,
- )
- return prompts, answers
+ A = prompt[:, 0 * w : 1 * w]
+ f_A = prompt[:, 1 * w : 2 * w]
+ B = prompt[:, 2 * w : 3 * w]
+ f_B = answer
+ # self.task_replace_color(A, f_A, B, f_B)
+ self.move_color(A, f_A, B, f_B)
+ return prompts.flatten(1), answers.flatten(1)
def save_quizzes(
self,