X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=reasoning.py;h=9462f87676687503576c030b1ac111c9a1fde0c3;hb=3ea12df1dcfc4006eb895fd62bb622e9aef6178c;hp=2874adc9c1c3f213bafce0c8dac53063c6463ab9;hpb=dfa00f17ae94d481ca0f8fd6ce96dbcaa4bbe06e;p=culture.git diff --git a/reasoning.py b/reasoning.py index 2874adc..9462f87 100755 --- a/reasoning.py +++ b/reasoning.py @@ -17,7 +17,7 @@ from torch.nn import functional as F import problem -class Reasoning(problem.Problem): +class Grids(problem.Problem): named_colors = [ ("white", [255, 255, 255]), ("red", [255, 0, 0]), @@ -87,9 +87,16 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): - prompts = prompts.reshape(prompts.size(0), self.height, -1) - answers = answers.reshape(answers.size(0), self.height, -1) + S = self.height * self.width + As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width) + f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view( + -1, self.height, self.width + ) + Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width) + prompts = torch.cat([As, f_As, Bs], dim=2) + answers = answers.reshape(answers.size(0), self.height, self.width) if predicted_prompts is None: predicted_prompts = 255 @@ -114,9 +121,13 @@ class Reasoning(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([192, 192, 192], device=c.device) + ( - 1 - c - ) * torch.tensor([255, 255, 255], device=c.device) + c = ( + (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) + * torch.tensor([64, 64, 64], device=c.device) + + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + ) y[...] = c[:, :, None, None] y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x @@ -186,7 +197,11 @@ class Reasoning(problem.Problem): image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0 + img.float() / 255.0, + image_name, + nrow=nrow, + padding=margin * 4, + pad_value=1.0, ) ###################################################################### @@ -293,7 +308,7 @@ class Reasoning(problem.Problem): X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] - def task_move(self, A, f_A, B, f_B): + def task_translate(self, A, f_A, B, f_B): di, dj = torch.randint(3, (2,)) - 1 nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 @@ -406,20 +421,234 @@ class Reasoning(problem.Problem): if n < nb_rec - 1: f_X[i1, j1] = c[-1] + def contact(self, X, i, j, q): + nq, nq_diag = 0, 0 + no = 0 + + for ii, jj in [ + (i - 1, j - 1), + (i - 1, j), + (i - 1, j + 1), + (i, j - 1), + (i, j + 1), + (i + 1, j - 1), + (i + 1, j), + (i + 1, j + 1), + ]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] != 0 and X[ii, jj] != q: + no += 1 + + for ii, jj in [ + (i - 1, j - 1), + (i - 1, j + 1), + (i + 1, j - 1), + (i + 1, j + 1), + ]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q: + nq_diag += 1 + + for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] == q: + nq += 1 + + return no, nq, nq_diag + + def task_count(self, A, f_A, B, f_B): + N = torch.randint(4, (1,)) + 2 + c = torch.randperm(len(self.colors) - 1)[:N] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + nb = torch.zeros(N, dtype=torch.int64) + q = torch.randint(N, (self.height * self.width,)) + k = torch.randperm(self.height * self.width) + for p in range(self.height * self.width): + i, j = k[p] % self.height, k[p] // self.height + no, nq, nq_diag = self.contact(X, i, j, c[q[p]]) + if no == 0 and nq_diag == 0: + if nq == 0: + if nb[q[p]] < self.width: + X[i, j] = c[q[p]] + nb[q[p]] += 1 + if nq == 1: + X[i, j] = c[q[p]] + + for n in range(N): + for j in range(nb[n]): + f_X[n, j] = c[n] + + def task_trajectory(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + di, dj = torch.randint(7, (2,)) - 3 + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + if ( + abs(di) + abs(dj) > 0 + and i + 2 * di >= 0 + and i + 2 * di < self.height + and j + 2 * dj >= 0 + and j + 2 * dj < self.width + ): + break + + k = 0 + while ( + i + k * di >= 0 + and i + k * di < self.height + and j + k * dj >= 0 + and j + k * dj < self.width + ): + if k < 2: + X[i + k * di, j + k * dj] = c[k] + f_X[i + k * di, j + k * dj] = c[min(k, 1)] + k += 1 + + def task_bounce(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:3] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + + def free(i, j): + return ( + i >= 0 + and i < self.height + and j >= 0 + and j < self.width + and f_X[i, j] == 0 + ) + + while True: + f_X[...] = 0 + X[...] = 0 + + for _ in range((self.height * self.width) // 10): + i, j = torch.randint(self.height, (1,)), torch.randint( + self.width, (1,) + ) + X[i, j] = c[0] + f_X[i, j] = c[0] + + while True: + di, dj = torch.randint(7, (2,)) - 3 + if abs(di) + abs(dj) == 1: + break + + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + + X[i, j] = c[1] + f_X[i, j] = c[1] + l = 0 + + while True: + l += 1 + if free(i + di, j + dj): + pass + elif free(i - dj, j + di): + di, dj = -dj, di + if free(i + dj, j - di): + if torch.rand(1) < 0.5: + di, dj = -di, -dj + elif free(i + dj, j - di): + di, dj = dj, -di + else: + break + + i, j = i + di, j + dj + f_X[i, j] = c[2] + if l <= 1: + X[i, j] = c[2] + + if l >= self.width: + break + + f_X[i, j] = c[1] + X[i, j] = c[1] + + if l > 3: + break + + def task_scale(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 + + i, j = torch.randint(self.height // 2, (1,)), torch.randint( + self.width // 2, (1,) + ) + + for X, f_X in [(A, f_A), (B, f_B)]: + for _ in range(3): + while True: + i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3: + break + X[i + i1 : i + i2, j + j1 : j + j2] = c[0] + f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0] + + X[i, j] = c[1] + f_X[0:2, 0:2] = c[1] + + def task_islands(self, A, f_A, B, f_B): + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + if ( + i == 0 + or i == self.height - 1 + or j == 0 + or j == self.width - 1 + or X[i, j] == 1 + ): + break + while True: + di, dj = torch.randint(3, (2,)) - 1 + if abs(di) + abs(dj) > 0: + break + X[i, j] = 1 + while True: + i, j = i + di, j + dj + if i < 0 or i >= self.height or j < 0 or j >= self.width: + break + b = ( + i == 0 + or i == self.height - 1 + or j == 0 + or j == self.width - 1 + or X[i, j] == 1 + ) + X[i, j] = 1 + if b: + break + ###################################################################### - def generate_prompts_and_answers(self, nb, device="cpu"): - tasks = [ + def all_tasks(self): + return [ self.task_replace_color, - self.task_move, + self.task_translate, self.task_grow, self.task_color_grow, self.task_frame, self.task_detect, + self.task_count, + self.task_trajectory, + self.task_bounce, + self.task_scale, + # self.task_islands, ] - prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) - answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) - w = self.width + + def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): + if tasks is None: + tasks = self.all_tasks() + + S = self.height * self.width + prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) + answers = torch.zeros(nb, S, dtype=torch.int64) for prompt, answer in tqdm.tqdm( zip(prompts, answers), @@ -427,10 +656,10 @@ class Reasoning(problem.Problem): desc="world generation", total=prompts.size(0), ): - A = prompt[:, 0 * w : 1 * w] - f_A = prompt[:, 1 * w : 2 * w] - B = prompt[:, 2 * w : 3 * w] - f_B = answer + A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width) + f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width) + B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width) + f_B = answer.view(self.height, self.width) task = tasks[torch.randint(len(tasks), (1,))] task(A, f_A, B, f_B) @@ -444,6 +673,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): self.save_image( result_dir, @@ -452,6 +682,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts, predicted_answers, + nrow, ) @@ -460,22 +691,35 @@ class Reasoning(problem.Problem): if __name__ == "__main__": import time - reasoning = Reasoning() + nb = 48 + + grids = Grids() + + for t in grids.all_tasks(): + # for t in [grids.task_islands]: + print(t.__name__) + prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) + grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) + + exit(0) + + nb = 72 start_time = time.perf_counter() - prompts, answers = reasoning.generate_prompts_and_answers(100) + prompts, answers = grids.generate_prompts_and_answers(nb) delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") - predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - predicted_answers = torch.logical_not(predicted_prompts) + m = torch.randint(2, (prompts.size(0),)) + predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) + predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - reasoning.save_quizzes( + grids.save_quizzes( "/tmp", "test", - prompts[:64], - answers[:64], + prompts[:nb], + answers[:nb], # You can add a bool to put a frame around the predicted parts - predicted_prompts[:64], - predicted_answers[:64], + predicted_prompts[:nb], + predicted_answers[:nb], )