def task_replace_color(self, A, f_A, B, f_B):
c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
- i1, j1, i2, j2 = self.rec_coo(A)
- A[i1:i2, j1:j2] = c1
- f_A[i1:i2, j1:j2] = c2
- for _ in range(3):
- i1, j1, i2, j2 = self.rec_coo(B)
- B[i1:i2, j1:j2] = c1
- f_B[i1:i2, j1:j2] = c2
-
- def move_color(self, A, f_A, B, f_B):
- c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
-
- i1, j1, i2, j2 = self.rec_coo(A)
- A[i1:i2, j1:j2] = c1
- f_A[i1:i2, j1:j2] = c1
-
- while True:
- i1, j1, i2, j2 = self.rec_coo(A)
- if i2 < self.height - 1:
- break
- A[i1:i2, j1:j2] = c2
- f_A[i1:i2, j1:j2] = c2
+ for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+ for _ in range(torch.randint(n, (1,)) + 1):
+ i1, j1, i2, j2 = self.rec_coo(X)
+ X[i1:i2, j1:j2] = c1
+ f_X[i1:i2, j1:j2] = c2
+
+ def task_move(self, A, f_A, B, f_B):
+ c = torch.randperm(len(self.colors) - 1)[:1] + 1
+ di, dj = torch.randint(2, (2,)) * 2 - 1
+ for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+ for _ in range(torch.randint(n, (1,)) + 1):
+ while True:
+ i1, j1, i2, j2 = self.rec_coo(X)
+ if (
+ i1 + di >= 0
+ and i2 + di < X.size(0)
+ and j1 + dj >= 0
+ and j2 + dj < X.size(1)
+ ):
+ break
+
+ X[i1:i2, j1:j2] = c
+ f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c
+
+ def task_grow(self, A, f_A, B, f_B):
+ c = torch.randperm(len(self.colors) - 1)[:1] + 1
+
+ for n, X, f_X in [(1, A, f_A), (3, B, f_B)]:
+ for _ in range(torch.randint(n, (1,)) + 1):
+ while True:
+ i1, j1, i2, j2 = self.rec_coo(X)
+ if i1 + 3 < i2 and j1 + 3 < j2:
+ break
+
+ X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+ f_X[i1:i2, j1:j2] = c
def generate_prompts_and_answers(self, nb):
+ tasks = [self.task_replace_color, self.task_move, self.task_grow]
prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
w = self.width
f_A = prompt[:, 1 * w : 2 * w]
B = prompt[:, 2 * w : 3 * w]
f_B = answer
- # self.task_replace_color(A, f_A, B, f_B)
- self.move_color(A, f_A, B, f_B)
+ tasks[torch.randint(len(tasks), (1,))](A, f_A, B, f_B)
return prompts.flatten(1), answers.flatten(1)
def save_quizzes(
lang = Lang(nb_iterations=4)
- prompts, answers = lang.generate_prompts_and_answers(24)
+ prompts, answers = lang.generate_prompts_and_answers(36)
predicted_prompts = torch.rand(prompts.size(0)) < 0.5
predicted_answers = torch.logical_not(predicted_prompts)
lang.save_quizzes(
- "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
+ "/tmp",
+ "test",
+ prompts,
+ answers, # predicted_prompts, predicted_answers
)
# start_time = time.perf_counter()