X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=reasoning.py;h=c4429471bbe907b85de91a9fb9c3c01160d8b2c3;hb=167c56ace610c3b975c702203bb7c7ddf74930ae;hp=003806a8070d50a272200e5e3bb8d9a59b13dbfb;hpb=0be6757c554ab40b08b4acfd90787a86f4c4cc5b;p=culture.git diff --git a/reasoning.py b/reasoning.py index 003806a..c442947 100755 --- a/reasoning.py +++ b/reasoning.py @@ -32,13 +32,12 @@ class Reasoning(problem.Problem): ("gray", [192, 192, 192]), ] - def __init__( - self, - ): + def __init__(self, device=torch.device("cpu")): self.colors = torch.tensor([c for _, c in self.named_colors]) self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)]) self.height = 10 self.width = 10 + self.device = device ###################################################################### @@ -170,7 +169,73 @@ class Reasoning(problem.Problem): def nb_token_values(self): return len(self.colors) + # This is quite a tensorial spaghetti mess to sample + # non-overlapping rectangles quickly, but it cut the generation of + # 100k samples from 1h50 with a lame pure-Python version to 4min with + # this one. 
def rec_coo(self, x, n, min_height=3, min_width=3):
    """Sample ``n`` pairwise non-overlapping rectangles inside the grid.

    Candidate rectangles are drawn in large batches with tensor
    operations, grouped ``n`` at a time, and the first group whose
    rectangles do not overlap is returned. Drawing is repeated until
    such a group is found.

    Args:
        x: Unused; kept so the signature stays the same as the
            ``rec_coo_`` variant and existing call sites keep working.
        n: Number of rectangles to return. Anything ``int()`` accepts,
            e.g. a Python int or a one-element tensor.
        min_height: Minimum number of rows covered by each rectangle.
        min_width: Minimum number of columns covered by each rectangle.

    Returns:
        A list of ``n`` tuples ``(i1, j1, i2, j2)`` of Python ints. Each
        rectangle covers rows ``i1..i2-1`` and columns ``j1..j2-1``, with
        ``0 <= i1 < i2 <= self.height`` and ``0 <= j1 < j2 <= self.width``.
        The list is empty when ``n <= 0``.

    Raises:
        ValueError: If the request can never be satisfied, i.e. the
            minimum size exceeds the grid or the rectangles need more
            cells than the grid contains.
    """
    # The group size used to be a hard-coded 3, which silently ignored n.
    nb_rec = int(n)
    if nb_rec <= 0:
        return []

    if min_height > self.height or min_width > self.width:
        raise ValueError(
            f"minimum size {min_height}x{min_width} does not fit in a "
            f"{self.height}x{self.width} grid"
        )

    # Necessary (not sufficient) condition; without it the rejection
    # loop below would spin forever on an impossible request.
    min_area = max(min_height, 1) * max(min_width, 1)
    if nb_rec * min_area > self.height * self.width:
        raise ValueError(
            f"{nb_rec} rectangles of at least {min_height}x{min_width} "
            f"cannot fit in a {self.height}x{self.width} grid"
        )

    nb_groups_per_draw = 1000

    def random_spans(nb, length):
        # Sorting random values yields a random permutation of
        # 0..length per row; the two entries < 2 mark two distinct cut
        # positions. The running count equals 1 from the first cut
        # (inclusive) to the second (exclusive), giving a random
        # non-empty run of ones within positions 0..length-1.
        cuts = (
            torch.rand(nb, length + 1, device=self.device).sort(dim=-1).indices
            < 2
        )
        return (cuts.long().cumsum(dim=1) == 1).long()

    def span_bounds(spans):
        # First index of each run and one-past-its-last index.
        size = spans.size(-1)
        pos = torch.arange(size, device=self.device)[None, :]
        start = size - (spans * (size - pos)).max(dim=-1).values
        stop = (spans * pos).max(dim=-1).values + 1
        return start.tolist(), stop.tolist()

    while True:
        rows = random_spans(nb_groups_per_draw * nb_rec, self.height)
        cols = random_spans(nb_groups_per_draw * nb_rec, self.width)

        big_enough = torch.logical_and(
            rows.sum(dim=-1) >= min_height, cols.sum(dim=-1) >= min_width
        )
        rows, cols = rows[big_enough], cols[big_enough]

        # Keep only complete groups of nb_rec candidates. An empty draw
        # is skipped explicitly: reshaping zero elements with -1 raises.
        nb_groups = rows.size(0) // nb_rec
        if nb_groups == 0:
            continue
        rows = rows[: nb_groups * nb_rec].reshape(nb_groups, nb_rec, -1)
        cols = cols[: nb_groups * nb_rec].reshape(nb_groups, nb_rec, -1)

        # Outer product of row and column spans gives one 0/1 mask per
        # rectangle; a group is disjoint when no cell is covered twice.
        masks = rows[:, :, :, None] * cols[:, :, None, :]
        disjoint = masks.sum(dim=1).flatten(1).max(dim=-1).values == 1

        if disjoint.any():
            break

    top, bottom = span_bounds(rows[disjoint][0])
    left, right = span_bounds(cols[disjoint][0])

    return list(zip(top, left, bottom, right))