f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+ def task_ortho(self, A, f_A, B, f_B):
+ nb_rec = 3
+ di, dj = torch.randint(3, (2,)) - 1
+ o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
+ m = torch.eye(2)
+ for _ in range(torch.randint(4, (1,))):
+ m = m @ o
+ if torch.rand(1) < 0.5:
+ m[0, :] = -m[0, :]
+
+ ci, cj = (self.height - 1) / 2, (self.width - 1) / 2
+
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ X[...] = 0
+ f_X[...] = 0
+
+ c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+
+ for r in range(nb_rec):
+ while True:
+ i1, i2 = torch.randint(self.height - 2, (2,)) + 1
+ j1, j2 = torch.randint(self.width - 2, (2,)) + 1
+ if (
+ i2 >= i1
+ and j2 >= j1
+ and max(i2 - i1, j2 - j1) >= 2
+ and min(i2 - i1, j2 - j1) <= 3
+ ):
+ break
+ X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
+
+ i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
+
+ i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
+ i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
+
+ i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
+ i1, i2 = i1.long() + di, i2.long() + di
+ j1, j2 = j1.long() + dj, j2.long() + dj
+ if i1 > i2:
+ i1, i2 = i2, i1
+ if j1 > j2:
+ j1, j2 = j2, j1
+
+ f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
+
+ n = F.one_hot(X.flatten()).sum(dim=0)[1:]
+ if (
+ n.sum() > self.height * self.width // 4
+ and (n > 0).long().sum() == nb_rec
+ ):
+ break
+
def task_islands(self, A, f_A, B, f_B):
pass
grids = Grids()
# for t in grids.all_tasks():
- for t in [grids.task_islands]:
+ for t in [grids.task_ortho]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)