From 82099816ad38a4ceaa6f59395f812fe29e467bb8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 19:34:26 +0300 Subject: [PATCH] Update. --- lang.py | 147 ++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 95 insertions(+), 52 deletions(-) diff --git a/lang.py b/lang.py index 4c3c64f..43550d7 100755 --- a/lang.py +++ b/lang.py @@ -172,74 +172,116 @@ class Lang(problem.Problem): def nb_token_values(self): return len(self.colors) - def rec_coo(self, x): + def rec_coo(self, x, n, min_height=3, min_width=3): while True: - i1, i2 = torch.randint(x.size(0), (2,)) - if i1 < i2 - 1: - break - while True: - j1, j2 = torch.randint(x.size(1), (2,)) - if j1 < j2 - 1: + collision = x.new_zeros(x.size()) + result = [] + for _ in range(n): + while True: + i1, i2 = torch.randint(x.size(0), (2,)) + if i1 + min_height <= i2: + break + while True: + j1, j2 = torch.randint(x.size(1), (2,)) + if j1 + min_width <= j2: + break + collision[i1:i2, j1:j2] += 1 + if collision.max() > 1: + break + result.append((i1, j1, i2, j2)) + if collision.max() == 1: break - return i1, j1, i2, j2 + return result ###################################################################### def task_replace_color(self, A, f_A, B, f_B): - c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1 - for n, X, f_X in [(1, A, f_A), (1, B, f_B)]: - for _ in range(torch.randint(n, (1,)) + 1): - i1, j1, i2, j2 = self.rec_coo(X) - X[i1:i2, j1:j2] = c1 - f_X[i1:i2, j1:j2] = c2 + N = 3 + c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(X, N) + for n in range(N): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] def task_move(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 - for n, X, f_X in [(1, A, f_A), (1, B, f_B)]: - c = torch.randperm(len(self.colors) - 1)[:1] + 1 - for _ in range(torch.randint(n, (1,)) + 1): - while True: - i1, j1, i2, j2 = self.rec_coo(X) - if ( - i1 + di >= 0 - and i2 + di < X.size(0) - and j1 + dj >= 0 - and j2 + dj < X.size(1) - ): - break - - X[i1:i2, j1:j2] = c - f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c + N = 3 + c = torch.randperm(len(self.colors) - 1)[:N] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + r = self.rec_coo(X, N) + i1, j1, i2, j2 = r[N - 1] + if ( + i1 + di >= 0 + and i2 + di < X.size(0) + and j1 + dj >= 0 + and j2 + dj < X.size(1) + ): + break + + for n in range(N): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + if n == N - 1: + f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n] + else: + f_X[i1:i2, j1:j2] = c[n] def task_grow(self, A, f_A, B, f_B): + di, dj = torch.randint(2, (2,)) * 2 - 1 + N = 3 + c = torch.randperm(len(self.colors) - 1)[:N] + 1 direction = torch.randint(2, (1,)) - for n, X, f_X in [(1, A, f_A), (1, B, f_B)]: - c = torch.randperm(len(self.colors) - 1)[:1] + 1 - for _ in range(torch.randint(n, (1,)) + 1): - while True: - i1, j1, i2, j2 = self.rec_coo(X) - if i1 + 3 < i2 and j1 + 3 < j2: - break + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + r = self.rec_coo(X, N) + i1, j1, i2, j2 = r[N - 1] + if i1 + 3 < i2 and j1 + 3 < j2: + break + + for n in range(N): + i1, j1, i2, j2 = r[n] + if n == N - 1: + if direction == 0: + X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n] + f_X[i1:i2, j1:j2] = c[n] + else: + X[i1:i2, j1:j2] = c[n] + f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n] + else: + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] - if direction == 0: - X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c - f_X[i1:i2, j1:j2] = c + def task_color_grow(self, A, f_A, B, f_B): + di, dj = torch.randint(2, (2,)) * 2 - 1 + N = 3 + c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1 + direction = torch.randint(2, (1,)) + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(X, N) + for n in range(N): + i1, j1, i2, j2 = r[n] + X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n] + f_X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n] + X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1] + if n == N - 1: + f_X[(i1 + i2) // 2 : i2, j1:j2] = c[2 * n + 1] else: - X[i1:i2, j1:j2] = c - f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c + f_X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1] def task_frame(self, A, f_A, B, f_B): - direction = torch.randint(2, (1,)) - for n, X, f_X in [(1, A, f_A), (1, B, f_B)]: - c = torch.randperm(len(self.colors) - 1)[:1] + 1 - for _ in range(torch.randint(n, (1,)) + 1): - while True: - i1, j1, i2, j2 = self.rec_coo(X) - if i1 + 3 < i2 and j1 + 3 < j2: - break - X[i1:i2, j1:j2] = c - f_X[i1:i2, j1:j2] = c - f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 + N = 3 + c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(X, N) + for n in range(N): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + if n == N - 1: + f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 ###################################################################### @@ -248,6 +290,7 @@ class Lang(problem.Problem): self.task_replace_color, self.task_move, self.task_grow, + self.task_color_grow, self.task_frame, ] prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) -- 2.39.5