X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=reasoning.py;h=5499bdfd6985081a2f8db5bf1ace853fcc354768;hb=30c76210e3ed2704b2a059208f385cb623c1486d;hp=9e26d64e34d1d796187d4f2fb9ad0f055562472c;hpb=5a0c2432316b0a413f1769ab429d33433a94e6e1;p=culture.git diff --git a/reasoning.py b/reasoning.py index 9e26d64..5499bdf 100755 --- a/reasoning.py +++ b/reasoning.py @@ -87,6 +87,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): prompts = prompts.reshape(prompts.size(0), self.height, -1) answers = answers.reshape(answers.size(0), self.height, -1) @@ -114,9 +115,13 @@ class Reasoning(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([192, 192, 192], device=c.device) + ( - 1 - c - ) * torch.tensor([255, 255, 255], device=c.device) + c = ( + (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) + * torch.tensor([192, 192, 192], device=c.device) + + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + ) y[...] = c[:, :, None, None] y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x @@ -186,7 +191,11 @@ class Reasoning(problem.Problem): image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0 + img.float() / 255.0, + image_name, + nrow=nrow, + padding=margin * 4, + pad_value=1.0, ) ###################################################################### @@ -465,24 +474,124 @@ class Reasoning(problem.Problem): for j in range(nb[n]): f_X[n, j] = c[n] - def task_count_(self, A, f_A, B, f_B): - N = torch.randint(3, (1,)) + 1 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + def task_trajectory(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - nb = torch.randint(self.width, (3,)) + 1 - k = torch.randperm(self.height * self.width)[: nb.sum()] - p = 0 - for n in range(N): - for m in range(nb[n]): - i, j = k[p] % self.height, k[p] // self.height - X[i, j] = c[n] - f_X[n, m] = c[n] - p += 1 + while True: + di, dj = torch.randint(7, (2,)) - 3 + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + if ( + abs(di) + abs(dj) > 0 + and i + 2 * di >= 0 + and i + 2 * di < self.height + and j + 2 * dj >= 0 + and j + 2 * dj < self.width + ): + break + + k = 0 + while ( + i + k * di >= 0 + and i + k * di < self.height + and j + k * dj >= 0 + and j + k * dj < self.width + ): + if k < 2: + X[i + k * di, j + k * dj] = c[k] + f_X[i + k * di, j + k * dj] = c[min(k, 1)] + k += 1 + + def task_bounce(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:3] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + + def free(i, j): + return ( + i >= 0 + and i < self.height + and j >= 0 + and j < self.width + and f_X[i, j] == 0 + ) + + while True: + f_X[...] = 0 + X[...] = 0 + + for _ in range((self.height * self.width) // 10): + i, j = torch.randint(self.height, (1,)), torch.randint( + self.width, (1,) + ) + X[i, j] = c[0] + f_X[i, j] = c[0] + + while True: + di, dj = torch.randint(7, (2,)) - 3 + if abs(di) + abs(dj) == 1: + break + + i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + + X[i, j] = c[1] + f_X[i, j] = c[1] + l = 0 + + while True: + l += 1 + if free(i + di, j + dj): + pass + elif free(i - dj, j + di): + di, dj = -dj, di + if free(i + dj, j - di): + if torch.rand(1) < 0.5: + di, dj = -di, -dj + elif free(i + dj, j - di): + di, dj = dj, -di + else: + break + + i, j = i + di, j + dj + f_X[i, j] = c[2] + if l <= 1: + X[i, j] = c[2] + + if l >= self.width: + break + + f_X[i, j] = c[1] + X[i, j] = c[1] + + if l > 3: + break + + def task_scale(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 + + i, j = torch.randint(self.height // 2, (1,)), torch.randint( + self.width // 2, (1,) + ) + + for X, f_X in [(A, f_A), (B, f_B)]: + for _ in range(3): + while True: + i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( + self.width // 2 + 1, (1,) + ) + if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3: + break + X[i + i1 : i + i2, j + j1 : j + j2] = c[0] + f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0] + + X[i, j] = c[1] + f_X[0:2, 0:2] = c[1] ###################################################################### - def generate_prompts_and_answers(self, nb, device="cpu"): - tasks = [ + def all_tasks(self): + return [ self.task_replace_color, self.task_translate, self.task_grow, @@ -490,7 +599,15 @@ class Reasoning(problem.Problem): self.task_frame, self.task_detect, self.task_count, + self.task_trajectory, + self.task_bounce, + self.task_scale, ] + + def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): + if tasks is None: + tasks = self.all_tasks() + prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width @@ -518,6 +635,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): self.save_image( result_dir, @@ -526,6 +644,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts, predicted_answers, + nrow, ) @@ -534,22 +653,32 @@ class Reasoning(problem.Problem): if __name__ == "__main__": import time + nb = 4 + reasoning = Reasoning() + for t in reasoning.all_tasks(): + print(t.__name__) + prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t]) + reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=1) + + exit(0) + start_time = time.perf_counter() - prompts, answers = reasoning.generate_prompts_and_answers(100) + prompts, answers = reasoning.generate_prompts_and_answers(nb) delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") - predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - predicted_answers = torch.logical_not(predicted_prompts) + # m = torch.randint(2, (prompts.size(0),)) + # predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) + # predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) reasoning.save_quizzes( "/tmp", "test", - prompts[:64], - answers[:64], + prompts[:nb], + answers[:nb], # You can add a bool to put a frame around the predicted parts - # predicted_prompts[:64], - # predicted_answers[:64], + # predicted_prompts[:nb], + # predicted_answers[:nb], )