From 1fb70f7c2d14ad7a1f600cdf472d179d84c438bf Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 22:59:59 +0300 Subject: [PATCH] Update. --- main.py | 2 +- reasoning.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 8b5d9a4..a954af6 100755 --- a/main.py +++ b/main.py @@ -250,7 +250,7 @@ if args.problem == "sky": speed=args.sky_speed, ) elif args.problem == "reasoning": - problem = reasoning.Reasoning() + problem = reasoning.Reasoning(device=device) else: raise ValueError diff --git a/reasoning.py b/reasoning.py index 003806a..b8d39ee 100755 --- a/reasoning.py +++ b/reasoning.py @@ -32,13 +32,12 @@ class Reasoning(problem.Problem): ("gray", [192, 192, 192]), ] - def __init__( - self, - ): + def __init__(self, device=torch.device("cpu")): self.colors = torch.tensor([c for _, c in self.named_colors]) self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)]) self.height = 10 self.width = 10 + self.device = device ###################################################################### @@ -171,6 +170,68 @@ class Reasoning(problem.Problem): return len(self.colors) def rec_coo(self, x, n, min_height=3, min_width=3): + K = 3 + N = 4000 + + while True: + v = ( + ( + torch.rand(N * K, self.height + 1, device=self.device) + .sort(dim=-1) + .indices + < 2 + ) + .long() + .cumsum(dim=1) + == 1 + ).long() + + h = ( + ( + torch.rand(N * K, self.width + 1, device=self.device) + .sort(dim=-1) + .indices + < 2 + ) + .long() + .cumsum(dim=1) + == 1 + ).long() + + i = torch.logical_and( + v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width + ) + + v, h = v[i], h[i] + v = v[: v.size(0) - v.size(0) % K] + h = h[: h.size(0) - h.size(0) % K] + v = v.reshape(v.size(0) // K, K, -1) + h = h.reshape(h.size(0) // K, K, -1) + + r = v[:, :, :, None] * h[:, :, None, :] + + valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1 + + v = v[valid] + h = h[valid] + + if v.size(0) > 0: + break + + av = torch.arange(v.size(2), device=self.device)[None, :] + ah = torch.arange(h.size(2), device=self.device)[None, :] + + return [ + (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1) + for i1, j1, i2, j2 in zip( + v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values, + h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values, + (v[0] * av).max(dim=-1).values, + (h[0] * ah).max(dim=-1).values, + ) + ] + + def rec_coo_(self, x, n, min_height=3, min_width=3): collision = x.new(x.size()) while True: collision[...] = 0 @@ -295,7 +356,7 @@ class Reasoning(problem.Problem): ###################################################################### - def generate_prompts_and_answers(self, nb): + def generate_prompts_and_answers(self, nb, device="cpu"): tasks = [ self.task_replace_color, self.task_move, @@ -304,8 +365,12 @@ class Reasoning(problem.Problem): self.task_frame, self.task_detect, ] - prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) - answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) + prompts = torch.zeros( + nb, self.height, self.width * 3, dtype=torch.int64, device=self.device + ) + answers = torch.zeros( + nb, self.height, self.width, dtype=torch.int64, device=self.device + ) w = self.width for prompt, answer in tqdm.tqdm( -- 2.39.5