def nb_token_values(self):
return len(self.colors)
- def rec_coo(self, x):
+ def rec_coo(self, x, n, min_height=3, min_width=3):
while True:
- i1, i2 = torch.randint(x.size(0), (2,))
- if i1 < i2 - 1:
- break
- while True:
- j1, j2 = torch.randint(x.size(1), (2,))
- if j1 < j2 - 1:
+ collision = x.new_zeros(x.size())
+ result = []
+ for _ in range(n):
+ while True:
+ i1, i2 = torch.randint(x.size(0), (2,))
+ if i1 + min_height <= i2:
+ break
+ while True:
+ j1, j2 = torch.randint(x.size(1), (2,))
+ if j1 + min_width <= j2:
+ break
+ collision[i1:i2, j1:j2] += 1
+ if collision.max() > 1:
+ break
+ result.append((i1, j1, i2, j2))
+ if collision.max() == 1:
break
- return i1, j1, i2, j2
+ return result
######################################################################
def task_replace_color(self, A, f_A, B, f_B):
- c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1
- for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
- for _ in range(torch.randint(n, (1,)) + 1):
- i1, j1, i2, j2 = self.rec_coo(X)
- X[i1:i2, j1:j2] = c1
- f_X[i1:i2, j1:j2] = c2
+ N = 3
+ c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(X, N)
+ for n in range(N):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
def task_move(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
- for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
- c = torch.randperm(len(self.colors) - 1)[:1] + 1
- for _ in range(torch.randint(n, (1,)) + 1):
- while True:
- i1, j1, i2, j2 = self.rec_coo(X)
- if (
- i1 + di >= 0
- and i2 + di < X.size(0)
- and j1 + dj >= 0
- and j2 + dj < X.size(1)
- ):
- break
-
- X[i1:i2, j1:j2] = c
- f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c
+ N = 3
+ c = torch.randperm(len(self.colors) - 1)[:N] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ r = self.rec_coo(X, N)
+ i1, j1, i2, j2 = r[N - 1]
+ if (
+ i1 + di >= 0
+ and i2 + di < X.size(0)
+ and j1 + dj >= 0
+ and j2 + dj < X.size(1)
+ ):
+ break
+
+ for n in range(N):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ if n == N - 1:
+ f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
+ else:
+ f_X[i1:i2, j1:j2] = c[n]
def task_grow(self, A, f_A, B, f_B):
+ di, dj = torch.randint(2, (2,)) * 2 - 1
+ N = 3
+ c = torch.randperm(len(self.colors) - 1)[:N] + 1
direction = torch.randint(2, (1,))
- for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
- c = torch.randperm(len(self.colors) - 1)[:1] + 1
- for _ in range(torch.randint(n, (1,)) + 1):
- while True:
- i1, j1, i2, j2 = self.rec_coo(X)
- if i1 + 3 < i2 and j1 + 3 < j2:
- break
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ while True:
+ r = self.rec_coo(X, N)
+ i1, j1, i2, j2 = r[N - 1]
+ if i1 + 3 < i2 and j1 + 3 < j2:
+ break
+
+ for n in range(N):
+ i1, j1, i2, j2 = r[n]
+ if n == N - 1:
+ if direction == 0:
+ X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+ else:
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+ else:
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
- if direction == 0:
- X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
- f_X[i1:i2, j1:j2] = c
+ def task_color_grow(self, A, f_A, B, f_B):
+ di, dj = torch.randint(2, (2,)) * 2 - 1
+ N = 3
+ c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1
+ direction = torch.randint(2, (1,))
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(X, N)
+ for n in range(N):
+ i1, j1, i2, j2 = r[n]
+ X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
+ f_X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n]
+ X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
+ if n == N - 1:
+ f_X[(i1 + i2) // 2 : i2, j1:j2] = c[2 * n + 1]
else:
- X[i1:i2, j1:j2] = c
- f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c
+ f_X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1]
def task_frame(self, A, f_A, B, f_B):
- direction = torch.randint(2, (1,))
- for n, X, f_X in [(1, A, f_A), (1, B, f_B)]:
- c = torch.randperm(len(self.colors) - 1)[:1] + 1
- for _ in range(torch.randint(n, (1,)) + 1):
- while True:
- i1, j1, i2, j2 = self.rec_coo(X)
- if i1 + 3 < i2 and j1 + 3 < j2:
- break
- X[i1:i2, j1:j2] = c
- f_X[i1:i2, j1:j2] = c
- f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+ N = 3
+ c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(X, N)
+ for n in range(N):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n]
+ if n == N - 1:
+ f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
######################################################################
self.task_replace_color,
self.task_move,
self.task_grow,
+ self.task_color_grow,
self.task_frame,
]
prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)