X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=reasoning.py;h=5499bdfd6985081a2f8db5bf1ace853fcc354768;hb=30c76210e3ed2704b2a059208f385cb623c1486d;hp=cd726cb83011e5f566d3764f8e156d09afdd4b3f;hpb=cc241681730e50ad149a68c612e3a06f2d4a71be;p=culture.git

diff --git a/reasoning.py b/reasoning.py
index cd726cb..5499bdf 100755
--- a/reasoning.py
+++ b/reasoning.py
@@ -87,6 +87,7 @@ class Reasoning(problem.Problem):
         answers,
         predicted_prompts=None,
         predicted_answers=None,
+        nrow=4,
     ):
         prompts = prompts.reshape(prompts.size(0), self.height, -1)
         answers = answers.reshape(answers.size(0), self.height, -1)
@@ -114,9 +115,13 @@ class Reasoning(problem.Problem):
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([192, 192, 192], device=c.device) + (
-                    1 - c
-                ) * torch.tensor([255, 255, 255], device=c.device)
+                c = (
+                    (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
+                    * torch.tensor([192, 192, 192], device=c.device)
+                    + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                )
                 y[...] = c[:, :, None, None]
 
             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
@@ -186,7 +191,11 @@ class Reasoning(problem.Problem):
         image_name = os.path.join(result_dir, filename)
 
         torchvision.utils.save_image(
-            img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0
+            img.float() / 255.0,
+            image_name,
+            nrow=nrow,
+            padding=margin * 4,
+            pad_value=1.0,
         )
 
     ######################################################################
@@ -581,8 +590,8 @@ class Reasoning(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb, device="cpu"):
-        tasks = [
+    def all_tasks(self):
+        return [
             self.task_replace_color,
             self.task_translate,
             self.task_grow,
@@ -594,6 +603,11 @@ class Reasoning(problem.Problem):
             self.task_bounce,
             self.task_scale,
         ]
+
+    def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"):
+        if tasks is None:
+            tasks = self.all_tasks()
+
         prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
         answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
         w = self.width
@@ -621,6 +635,7 @@ class Reasoning(problem.Problem):
         answers,
         predicted_prompts=None,
         predicted_answers=None,
+        nrow=4,
     ):
         self.save_image(
             result_dir,
@@ -629,6 +644,7 @@ class Reasoning(problem.Problem):
             answers,
             predicted_prompts,
             predicted_answers,
+            nrow,
         )
 
 
@@ -637,22 +653,32 @@ ######################################################################
 if __name__ == "__main__":
     import time
 
+    nb = 4
+
     reasoning = Reasoning()
 
+    for t in reasoning.all_tasks():
+        print(t.__name__)
+        prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t])
+        reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=1)
+
+    exit(0)
+
     start_time = time.perf_counter()
-    prompts, answers = reasoning.generate_prompts_and_answers(100)
+    prompts, answers = reasoning.generate_prompts_and_answers(nb)
     delay = time.perf_counter() - start_time
     print(f"{prompts.size(0)/delay:02f} seq/s")
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.logical_not(predicted_prompts)
+    # m = torch.randint(2, (prompts.size(0),))
+    # predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
+    # predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
     reasoning.save_quizzes(
         "/tmp",
         "test",
-        prompts[:64],
-        answers[:64],
+        prompts[:nb],
+        answers[:nb],
         # You can add a bool to put a frame around the predicted parts
-        # predicted_prompts[:64],
-        # predicted_answers[:64],
+        # predicted_prompts[:nb],
+        # predicted_answers[:nb],
     )
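
For reference, a minimal sketch of how the interface introduced by this patch could be driven from another script. It assumes the patched reasoning.py above is importable and that Reasoning() needs no constructor arguments (as in the __main__ block); the output directory and the quiz count nb are illustrative values, not part of the patch.

# Sketch only: exercise the per-task generation added by this patch.
# Assumes the patched reasoning.py is on the import path; "/tmp" and
# nb = 4 are arbitrary illustrative values.
from reasoning import Reasoning

reasoning = Reasoning()
nb = 4

# all_tasks() now exposes the list of task methods, and
# generate_prompts_and_answers() accepts an optional tasks= subset.
for task in reasoning.all_tasks():
    prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[task])

    # save_quizzes() forwards the new nrow= argument to save_image(),
    # so each task gets its own image with one quiz per row.
    reasoning.save_quizzes("/tmp", task.__name__, prompts[:nb], answers[:nb], nrow=1)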