X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=reasoning.py;h=c4429471bbe907b85de91a9fb9c3c01160d8b2c3;hb=167c56ace610c3b975c702203bb7c7ddf74930ae;hp=92699e82f50f06e9edcd156a84d1e02366e8ace0;hpb=3060e09fe9c6d71f44482308c5876078c527bd70;p=culture.git diff --git a/reasoning.py b/reasoning.py index 92699e8..c442947 100755 --- a/reasoning.py +++ b/reasoning.py @@ -32,13 +32,12 @@ class Reasoning(problem.Problem): ("gray", [192, 192, 192]), ] - def __init__( - self, - ): + def __init__(self, device=torch.device("cpu")): self.colors = torch.tensor([c for _, c in self.named_colors]) self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)]) self.height = 10 self.width = 10 + self.device = device ###################################################################### @@ -170,7 +169,73 @@ class Reasoning(problem.Problem): def nb_token_values(self): return len(self.colors) + # This is quite a tensorial spaghetti mess to sample + # non-overlapping rectangles quickly, but it made the generation of + # 100k samples go from 1h50 with lame pure Python code to 4min with + # this one. 
def rec_coo(self, x, n, min_height=3, min_width=3): + K = 3 + N = 1000 + + while True: + v = ( + ( + torch.rand(N * K, self.height + 1, device=self.device) + .sort(dim=-1) + .indices + < 2 + ) + .long() + .cumsum(dim=1) + == 1 + ).long() + + h = ( + ( + torch.rand(N * K, self.width + 1, device=self.device) + .sort(dim=-1) + .indices + < 2 + ) + .long() + .cumsum(dim=1) + == 1 + ).long() + + i = torch.logical_and( + v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width + ) + + v, h = v[i], h[i] + v = v[: v.size(0) - v.size(0) % K] + h = h[: h.size(0) - h.size(0) % K] + v = v.reshape(v.size(0) // K, K, -1) + h = h.reshape(h.size(0) // K, K, -1) + + r = v[:, :, :, None] * h[:, :, None, :] + + valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1 + + v = v[valid] + h = h[valid] + + if v.size(0) > 0: + break + + av = torch.arange(v.size(2), device=self.device)[None, :] + ah = torch.arange(h.size(2), device=self.device)[None, :] + + return [ + (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1) + for i1, j1, i2, j2 in zip( + v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values, + h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values, + (v[0] * av).max(dim=-1).values, + (h[0] * ah).max(dim=-1).values, + ) + ] + + def rec_coo_(self, x, n, min_height=3, min_width=3): collision = x.new(x.size()) while True: collision[...] 
= 0 @@ -295,7 +360,7 @@ class Reasoning(problem.Problem): ###################################################################### - def generate_prompts_and_answers(self, nb): + def generate_prompts_and_answers(self, nb, device="cpu"): tasks = [ self.task_replace_color, self.task_move, @@ -307,13 +372,20 @@ class Reasoning(problem.Problem): prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width - for prompt, answer in zip(prompts, answers): + + for prompt, answer in tqdm.tqdm( + zip(prompts, answers), + dynamic_ncols=True, + desc="world generation", + total=prompts.size(0), + ): A = prompt[:, 0 * w : 1 * w] f_A = prompt[:, 1 * w : 2 * w] B = prompt[:, 2 * w : 3 * w] f_B = answer task = tasks[torch.randint(len(tasks), (1,))] task(A, f_A, B, f_B) + return prompts.flatten(1), answers.flatten(1) def save_quizzes(