+ if direction == 0:
+ X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+ f_X[i1:i2, j1:j2] = c
+ else:
+ X[i1:i2, j1:j2] = c
+ f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+
+ def task_frame(self, A, f_A, B, f_B):
+ direction = torch.randint(2, (1,))
+ for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
+ c = torch.randperm(len(self.colors) - 1)[:1] + 1
+ for _ in range(torch.randint(n, (1,)) + 1):
+ while True:
+ i1, j1, i2, j2 = self.rec_coo(X)
+ if i1 + 3 < i2 and j1 + 3 < j2:
+ break
+ X[i1:i2, j1:j2] = c