X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=reasoning.py;h=5499bdfd6985081a2f8db5bf1ace853fcc354768;hb=30c76210e3ed2704b2a059208f385cb623c1486d;hp=c4429471bbe907b85de91a9fb9c3c01160d8b2c3;hpb=167c56ace610c3b975c702203bb7c7ddf74930ae;p=culture.git diff --git a/reasoning.py b/reasoning.py index c442947..5499bdf 100755 --- a/reasoning.py +++ b/reasoning.py @@ -27,9 +27,9 @@ class Reasoning(problem.Problem): ("cyan", [0, 255, 255]), ("violet", [255, 0, 255]), ("lightgreen", [192, 255, 192]), - ("pink", [255, 192, 192]), + ("brown", [165, 42, 42]), ("lightblue", [192, 192, 255]), - ("gray", [192, 192, 192]), + ("gray", [128, 128, 128]), ] def __init__(self, device=torch.device("cpu")): @@ -42,6 +42,31 @@ class Reasoning(problem.Problem): ###################################################################### def frame2img(self, x, scale=15): + x = x.reshape(x.size(0), self.height, -1) + m = torch.logical_and(x >= 0, x < self.nb_token_values()).long() + x = self.colors[x * m].permute(0, 3, 1, 2) + s = x.shape + x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) + + x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 + x[:, :, torch.arange(0, x.size(2), scale), :] = 0 + x = x[:, :, 1:, 1:] + + for n in range(m.size(0)): + for i in range(m.size(1)): + for j in range(m.size(2)): + if m[n, i, j] == 0: + for k in range(2, scale - 2): + for l in [0, 1]: + x[n, :, i * scale + k, j * scale + k - l] = 0 + x[ + n, :, i * scale + scale - 1 - k, j * scale + k - l + ] = 0 + + return x + + def frame2img_(self, x, scale=15): x = x.reshape(x.size(0), self.height, -1) x = self.colors[x].permute(0, 3, 1, 2) s = x.shape @@ -62,6 +87,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): prompts = prompts.reshape(prompts.size(0), self.height, -1) answers = answers.reshape(answers.size(0), self.height, -1) @@ -89,9 +115,13 @@ class Reasoning(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([192, 192, 192], device=c.device) + ( - 1 - c - ) * torch.tensor([255, 255, 255], device=c.device) + c = ( + (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) + * torch.tensor([192, 192, 192], device=c.device) + + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + ) y[...] = c[:, :, None, None] y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x @@ -161,7 +191,11 @@ class Reasoning(problem.Problem): image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0 + img.float() / 255.0, + image_name, + nrow=nrow, + padding=margin * 4, + pad_value=1.0, ) ###################################################################### @@ -171,16 +205,15 @@ class Reasoning(problem.Problem): # That's quite a tensorial spaghetti mess to sample # non-overlapping rectangles quickly, but made the generation of - # 100k samples from 1h50 with a lame pure python code to 4min with - # this one. - def rec_coo(self, x, n, min_height=3, min_width=3): - K = 3 - N = 1000 + # 100k samples go from 1h50 with a lame pure python code to 3min30s + # with this one. + def rec_coo(self, nb_rec, min_height=3, min_width=3): + nb_trials = 200 while True: v = ( ( - torch.rand(N * K, self.height + 1, device=self.device) + torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device) .sort(dim=-1) .indices < 2 @@ -192,7 +225,7 @@ class Reasoning(problem.Problem): h = ( ( - torch.rand(N * K, self.width + 1, device=self.device) + torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device) .sort(dim=-1) .indices < 2 @@ -207,10 +240,10 @@ class Reasoning(problem.Problem): ) v, h = v[i], h[i] - v = v[: v.size(0) - v.size(0) % K] - h = h[: h.size(0) - h.size(0) % K] - v = v.reshape(v.size(0) // K, K, -1) - h = h.reshape(h.size(0) // K, K, -1) + v = v[: v.size(0) - v.size(0) % nb_rec] + h = h[: h.size(0) - h.size(0) % nb_rec] + v = v.reshape(v.size(0) // nb_rec, nb_rec, -1) + h = h.reshape(h.size(0) // nb_rec, nb_rec, -1) r = v[:, :, :, None] * h[:, :, None, :] @@ -260,23 +293,23 @@ class Reasoning(problem.Problem): ###################################################################### def task_replace_color(self, A, f_A, B, f_B): - N = 3 - c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(X, N) - for n in range(N): + r = self.rec_coo(nb_rec) + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] - def task_move(self, A, f_A, B, f_B): - di, dj = torch.randint(2, (2,)) * 2 - 1 - N = 3 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + def task_translate(self, A, f_A, B, f_B): + di, dj = torch.randint(3, (2,)) - 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(X, N) - i1, j1, i2, j2 = r[N - 1] + r = self.rec_coo(nb_rec) + i1, j1, i2, j2 = r[nb_rec - 1] if ( i1 + di >= 0 and i2 + di < X.size(0) @@ -285,29 +318,29 @@ class Reasoning(problem.Problem): ): break - for n in range(N): + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] - if n == N - 1: + if n == nb_rec - 1: f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n] else: f_X[i1:i2, j1:j2] = c[n] def task_grow(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 - N = 3 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 direction = torch.randint(2, (1,)) for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(X, N) - i1, j1, i2, j2 = r[N - 1] + r = self.rec_coo(nb_rec) + i1, j1, i2, j2 = r[nb_rec - 1] if i1 + 3 < i2 and j1 + 3 < j2: break - for n in range(N): + for n in range(nb_rec): i1, j1, i2, j2 = r[n] - if n == N - 1: + if n == nb_rec - 1: if direction == 0: X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n] f_X[i1:i2, j1:j2] = c[n] @@ -320,55 +353,261 @@ class Reasoning(problem.Problem): def task_color_grow(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 - N = 3 - c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1 - direction = torch.randint(2, (1,)) + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1 + direction = torch.randint(4, (1,)) for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(X, N) - for n in range(N): + r = self.rec_coo(nb_rec) + for n in range(nb_rec): i1, j1, i2, j2 = r[n] - i = (i1 + i2) // 2 X[i1:i2, j1:j2] = c[2 * n] - X[i : i + 1, j1:j2] = c[2 * n + 1] f_X[i1:i2, j1:j2] = c[2 * n] - if n == N - 1: - f_X[i:i2, j1:j2] = c[2 * n + 1] - else: - f_X[i : i + 1, j1:j2] = c[2 * n + 1] + # Not my proudest moment + if direction == 0: + i = (i1 + i2) // 2 + X[i : i + 1, j1:j2] = c[2 * n + 1] + if n == nb_rec - 1: + f_X[i:i2, j1:j2] = c[2 * n + 1] + else: + f_X[i : i + 1, j1:j2] = c[2 * n + 1] + elif direction == 1: + i = (i1 + i2 - 1) // 2 + X[i : i + 1, j1:j2] = c[2 * n + 1] + if n == nb_rec - 1: + f_X[i1 : i + 1, j1:j2] = c[2 * n + 1] + else: + f_X[i : i + 1, j1:j2] = c[2 * n + 1] + elif direction == 2: + j = (j1 + j2) // 2 + X[i1:i2, j : j + 1] = c[2 * n + 1] + if n == nb_rec - 1: + f_X[i1:i2, j:j2] = c[2 * n + 1] + else: + f_X[i1:i2, j : j + 1] = c[2 * n + 1] + elif direction == 3: + j = (j1 + j2 - 1) // 2 + X[i1:i2, j : j + 1] = c[2 * n + 1] + if n == nb_rec - 1: + f_X[i1:i2, j1 : j + 1] = c[2 * n + 1] + else: + f_X[i1:i2, j : j + 1] = c[2 * n + 1] def task_frame(self, A, f_A, B, f_B): - N = 3 - c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(X, N) - for n in range(N): + r = self.rec_coo(nb_rec) + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n] - if n == N - 1: + if n == nb_rec - 1: f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 def task_detect(self, A, f_A, B, f_B): - N = 3 - c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(X, N) - for n in range(N): + r = self.rec_coo(nb_rec) + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] - f_X[i1, j1] = c[-1] + if n < nb_rec - 1: + f_X[i1, j1] = c[-1] + + def task_count(self, A, f_A, B, f_B): + N = torch.randint(4, (1,)) + 2 + c = torch.randperm(len(self.colors) - 1)[:N] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + + def contact(i, j, q): + nq, nq_diag = 0, 0 + no = 0 + + for ii, jj in [ + (i - 1, j - 1), + (i - 1, j), + (i - 1, j + 1), + (i, j - 1), + (i, j + 1), + (i + 1, j - 1), + (i + 1, j), + (i + 1, j + 1), + ]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] != 0 and X[ii, jj] != q: + no += 1 + + for ii, jj in [ + (i - 1, j - 1), + (i - 1, j + 1), + (i + 1, j - 1), + (i + 1, j + 1), + ]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q: + nq_diag += 1 + + for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] == q: + nq += 1 + + return no, nq, nq_diag + + nb = torch.zeros(N, dtype=torch.int64) + q = torch.randint(N, (self.height * self.width,)) + k = torch.randperm(self.height * self.width) + for p in range(self.height * self.width): + i, j = k[p] % self.height, k[p] // self.height + no, nq, nq_diag = contact(i, j, c[q[p]]) + if no == 0 and nq_diag == 0: + if nq == 0: + if nb[q[p]] < self.width: + X[i, j] = c[q[p]] + nb[q[p]] += 1 + if nq == 1: + X[i, j] = c[q[p]] + + for n in range(N): + for j in range(nb[n]): + f_X[n, j] = c[n] + + def task_trajectory(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + di, dj = torch.randint(7, (2,)) - 3 + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + if ( + abs(di) + abs(dj) > 0 + and i + 2 * di >= 0 + and i + 2 * di < self.height + and j + 2 * dj >= 0 + and j + 2 * dj < self.width + ): + break + + k = 0 + while ( + i + k * di >= 0 + and i + k * di < self.height + and j + k * dj >= 0 + and j + k * dj < self.width + ): + if k < 2: + X[i + k * di, j + k * dj] = c[k] + f_X[i + k * di, j + k * dj] = c[min(k, 1)] + k += 1 + + def task_bounce(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:3] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + + def free(i, j): + return ( + i >= 0 + and i < self.height + and j >= 0 + and j < self.width + and f_X[i, j] == 0 + ) + + while True: + f_X[...] = 0 + X[...] = 0 + + for _ in range((self.height * self.width) // 10): + i, j = torch.randint(self.height, (1,)), torch.randint( + self.width, (1,) + ) + X[i, j] = c[0] + f_X[i, j] = c[0] + + while True: + di, dj = torch.randint(7, (2,)) - 3 + if abs(di) + abs(dj) == 1: + break + + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + + X[i, j] = c[1] + f_X[i, j] = c[1] + l = 0 + + while True: + l += 1 + if free(i + di, j + dj): + pass + elif free(i - dj, j + di): + di, dj = -dj, di + if free(i + dj, j - di): + if torch.rand(1) < 0.5: + di, dj = -di, -dj + elif free(i + dj, j - di): + di, dj = dj, -di + else: + break + + i, j = i + di, j + dj + f_X[i, j] = c[2] + if l <= 1: + X[i, j] = c[2] + + if l >= self.width: + break + + f_X[i, j] = c[1] + X[i, j] = c[1] + + if l > 3: + break + + def task_scale(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 + + i, j = torch.randint(self.height // 2, (1,)), torch.randint( + self.width // 2, (1,) + ) + + for X, f_X in [(A, f_A), (B, f_B)]: + for _ in range(3): + while True: + i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3: + break + X[i + i1 : i + i2, j + j1 : j + j2] = c[0] + f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0] + + X[i, j] = c[1] + f_X[0:2, 0:2] = c[1] ###################################################################### - def generate_prompts_and_answers(self, nb, device="cpu"): - tasks = [ + def all_tasks(self): + return [ self.task_replace_color, - self.task_move, + self.task_translate, self.task_grow, self.task_color_grow, self.task_frame, self.task_detect, + self.task_count, + self.task_trajectory, + self.task_bounce, + self.task_scale, ] + + def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): + if tasks is None: + tasks = self.all_tasks() + prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width @@ -396,6 +635,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): self.save_image( result_dir, @@ -404,6 +644,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts, predicted_answers, + nrow, ) @@ -412,21 +653,32 @@ class Reasoning(problem.Problem): if __name__ == "__main__": import time + nb = 4 + reasoning = Reasoning() + for t in reasoning.all_tasks(): + print(t.__name__) + prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t]) + reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=1) + + exit(0) + start_time = time.perf_counter() - prompts, answers = reasoning.generate_prompts_and_answers(100) + prompts, answers = reasoning.generate_prompts_and_answers(nb) delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") - # predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - # predicted_answers = torch.logical_not(predicted_prompts) + # m = torch.randint(2, (prompts.size(0),)) + # predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) + # predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) reasoning.save_quizzes( "/tmp", "test", - prompts[:36], - answers[:36], + prompts[:nb], + answers[:nb], # You can add a bool to put a frame around the predicted parts - # predicted_prompts, predicted_answers + # predicted_prompts[:nb], + # predicted_answers[:nb], )