X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=reasoning.py;h=9462f87676687503576c030b1ac111c9a1fde0c3;hb=3ea12df1dcfc4006eb895fd62bb622e9aef6178c;hp=c545e97f0b6517d7bbec6d163a55b447c2e0a989;hpb=cb64cb72d5fff242cd90c13e4d8b9ca43e6e8837;p=culture.git diff --git a/reasoning.py b/reasoning.py index c545e97..9462f87 100755 --- a/reasoning.py +++ b/reasoning.py @@ -17,7 +17,7 @@ from torch.nn import functional as F import problem -class Reasoning(problem.Problem): +class Grids(problem.Problem): named_colors = [ ("white", [255, 255, 255]), ("red", [255, 0, 0]), @@ -87,9 +87,16 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): - prompts = prompts.reshape(prompts.size(0), self.height, -1) - answers = answers.reshape(answers.size(0), self.height, -1) + S = self.height * self.width + As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width) + f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view( + -1, self.height, self.width + ) + Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width) + prompts = torch.cat([As, f_As, Bs], dim=2) + answers = answers.reshape(answers.size(0), self.height, self.width) if predicted_prompts is None: predicted_prompts = 255 @@ -114,9 +121,13 @@ class Reasoning(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([192, 192, 192], device=c.device) + ( - 1 - c - ) * torch.tensor([255, 255, 255], device=c.device) + c = ( + (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) + * torch.tensor([64, 64, 64], device=c.device) + + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + ) y[...] = c[:, :, None, None] y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x @@ -186,7 +197,11 @@ class Reasoning(problem.Problem): image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0 + img.float() / 255.0, + image_name, + nrow=nrow, + padding=margin * 4, + pad_value=1.0, ) ###################################################################### @@ -406,53 +421,52 @@ class Reasoning(problem.Problem): if n < nb_rec - 1: f_X[i1, j1] = c[-1] + def contact(self, X, i, j, q): + nq, nq_diag = 0, 0 + no = 0 + + for ii, jj in [ + (i - 1, j - 1), + (i - 1, j), + (i - 1, j + 1), + (i, j - 1), + (i, j + 1), + (i + 1, j - 1), + (i + 1, j), + (i + 1, j + 1), + ]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] != 0 and X[ii, jj] != q: + no += 1 + + for ii, jj in [ + (i - 1, j - 1), + (i - 1, j + 1), + (i + 1, j - 1), + (i + 1, j + 1), + ]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q: + nq_diag += 1 + + for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] == q: + nq += 1 + + return no, nq, nq_diag + def task_count(self, A, f_A, B, f_B): N = torch.randint(4, (1,)) + 2 c = torch.randperm(len(self.colors) - 1)[:N] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - - def contact(i, j, q): - nq, nq_diag = 0, 0 - no = 0 - - for ii, jj in [ - (i - 1, j - 1), - (i - 1, j), - (i - 1, j + 1), - (i, j - 1), - (i, j + 1), - (i + 1, j - 1), - (i + 1, j), - (i + 1, j + 1), - ]: - if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: - if X[ii, jj] != 0 and X[ii, jj] != q: - no += 1 - - for ii, jj in [ - (i - 1, j - 1), - (i - 1, j + 1), - (i + 1, j - 1), - (i + 1, j + 1), - ]: - if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: - if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q: - nq_diag += 1 - - for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]: - if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: - if X[ii, jj] == q: - nq += 1 - - return no, nq, nq_diag - nb = torch.zeros(N, dtype=torch.int64) q = torch.randint(N, (self.height * self.width,)) k = torch.randperm(self.height * self.width) for p in range(self.height * self.width): i, j = k[p] % self.height, k[p] // self.height - no, nq, nq_diag = contact(i, j, c[q[p]]) + no, nq, nq_diag = self.contact(X, i, j, c[q[p]]) if no == 0 and nq_diag == 0: if nq == 0: if nb[q[p]] < self.width: @@ -555,10 +569,66 @@ class Reasoning(problem.Problem): if l > 3: break + def task_scale(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 + + i, j = torch.randint(self.height // 2, (1,)), torch.randint( + self.width // 2, (1,) + ) + + for X, f_X in [(A, f_A), (B, f_B)]: + for _ in range(3): + while True: + i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3: + break + X[i + i1 : i + i2, j + j1 : j + j2] = c[0] + f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0] + + X[i, j] = c[1] + f_X[0:2, 0:2] = c[1] + + def task_islands(self, A, f_A, B, f_B): + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + if ( + i == 0 + or i == self.height - 1 + or j == 0 + or j == self.width - 1 + or X[i, j] == 1 + ): + break + while True: + di, dj = torch.randint(3, (2,)) - 1 + if abs(di) + abs(dj) > 0: + break + X[i, j] = 1 + while True: + i, j = i + di, j + dj + if i < 0 or i >= self.height or j < 0 or j >= self.width: + break + b = ( + i == 0 + or i == self.height - 1 + or j == 0 + or j == self.width - 1 + or X[i, j] == 1 + ) + X[i, j] = 1 + if b: + break + ###################################################################### - def generate_prompts_and_answers(self, nb, device="cpu"): - tasks = [ + def all_tasks(self): + return [ self.task_replace_color, self.task_translate, self.task_grow, @@ -568,10 +638,17 @@ class Reasoning(problem.Problem): self.task_count, self.task_trajectory, self.task_bounce, + self.task_scale, + # self.task_islands, ] - prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) - answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) - w = self.width + + def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): + if tasks is None: + tasks = self.all_tasks() + + S = self.height * self.width + prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) + answers = torch.zeros(nb, S, dtype=torch.int64) for prompt, answer in tqdm.tqdm( zip(prompts, answers), @@ -579,10 +656,10 @@ class Reasoning(problem.Problem): desc="world generation", total=prompts.size(0), ): - A = prompt[:, 0 * w : 1 * w] - f_A = prompt[:, 1 * w : 2 * w] - B = prompt[:, 2 * w : 3 * w] - f_B = answer + A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width) + f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width) + B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width) + f_B = answer.view(self.height, self.width) task = tasks[torch.randint(len(tasks), (1,))] task(A, f_A, B, f_B) @@ -596,6 +673,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): self.save_image( result_dir, @@ -604,6 +682,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts, predicted_answers, + nrow, ) @@ -612,22 +691,35 @@ class Reasoning(problem.Problem): if __name__ == "__main__": import time - reasoning = Reasoning() + nb = 48 + + grids = Grids() + + for t in grids.all_tasks(): + # for t in [grids.task_islands]: + print(t.__name__) + prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) + grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) + + exit(0) + + nb = 72 start_time = time.perf_counter() - prompts, answers = reasoning.generate_prompts_and_answers(100) + prompts, answers = grids.generate_prompts_and_answers(nb) delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") - predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - predicted_answers = torch.logical_not(predicted_prompts) + m = torch.randint(2, (prompts.size(0),)) + predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) + predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - reasoning.save_quizzes( + grids.save_quizzes( "/tmp", "test", - prompts[:64], - answers[:64], + prompts[:nb], + answers[:nb], # You can add a bool to put a frame around the predicted parts - # predicted_prompts[:64], - # predicted_answers[:64], + predicted_prompts[:nb], + predicted_answers[:nb], )