From 30c76210e3ed2704b2a059208f385cb623c1486d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 5 Jul 2024 21:56:42 +0300 Subject: [PATCH] Update. --- quizz_machine.py | 30 ++++++++++++++++++++++------ reasoning.py | 52 ++++++++++++++++++++++++++++++++++++------------ sky.py | 12 ++++++----- 3 files changed, 70 insertions(+), 24 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index 62ae8ce..632c9ae 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -238,10 +238,17 @@ class QuizzMachine: result_dir, "culture_w_quizzes", self.train_w_quizzes[:72], - prediction=True, + show_to_be_predicted=True, ) - def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False): + def save_quizzes( + self, + result_dir, + filename_prefix, + quizzes, + show_to_be_predicted=False, + mistakes=None, + ): quizzes = quizzes.clone() forward = quizzes[quizzes[:, 0] == self.token_forward] ib = quizzes[:, 0] == self.token_backward @@ -249,9 +256,17 @@ class QuizzMachine: assert forward.size(0) + backward.size(0) == quizzes.size(0) quizzes[ib] = self.reverse_time(quizzes[ib]) - if prediction: - predicted_prompts = ib - predicted_answers = torch.logical_not(ib) + if show_to_be_predicted: + predicted_prompts = ib.long() + predicted_answers = 1 - predicted_prompts + if mistakes is not None: + # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct + predicted_prompts *= mistakes + predicted_answers *= mistakes + else: + # 0/2 ~ not-to-predict / to predict + predicted_prompts *= 2 + predicted_answers *= 2 else: predicted_prompts = None predicted_answers = None @@ -409,11 +424,14 @@ class QuizzMachine: device=self.device, ) + mistakes = (input == result).flatten(1).long().min(dim=1).values * 2 - 1 + self.save_quizzes( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=result[:72], - prediction=True, + show_to_be_predicted=True, + mistakes=mistakes[:72], ) return main_test_accuracy diff --git a/reasoning.py b/reasoning.py index cd726cb..5499bdf 100755 --- a/reasoning.py +++ b/reasoning.py @@ -87,6 +87,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): prompts = prompts.reshape(prompts.size(0), self.height, -1) answers = answers.reshape(answers.size(0), self.height, -1) @@ -114,9 +115,13 @@ class Reasoning(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([192, 192, 192], device=c.device) + ( - 1 - c - ) * torch.tensor([255, 255, 255], device=c.device) + c = ( + (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) + * torch.tensor([192, 192, 192], device=c.device) + + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + ) y[...] = c[:, :, None, None] y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x @@ -186,7 +191,11 @@ class Reasoning(problem.Problem): image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0 + img.float() / 255.0, + image_name, + nrow=nrow, + padding=margin * 4, + pad_value=1.0, ) ###################################################################### @@ -581,8 +590,8 @@ class Reasoning(problem.Problem): ###################################################################### - def generate_prompts_and_answers(self, nb, device="cpu"): - tasks = [ + def all_tasks(self): + return [ self.task_replace_color, self.task_translate, self.task_grow, @@ -594,6 +603,11 @@ class Reasoning(problem.Problem): self.task_bounce, self.task_scale, ] + + def generate_prompts_and_answers(self, nb, tasks=None, device="cpu"): + if tasks is None: + tasks = self.all_tasks() + prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width @@ -621,6 +635,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts=None, predicted_answers=None, + nrow=4, ): self.save_image( result_dir, @@ -629,6 +644,7 @@ class Reasoning(problem.Problem): answers, predicted_prompts, predicted_answers, + nrow, ) @@ -637,22 +653,32 @@ class Reasoning(problem.Problem): if __name__ == "__main__": import time + nb = 4 + reasoning = Reasoning() + for t in reasoning.all_tasks(): + print(t.__name__) + prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t]) + reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=1) + + exit(0) + start_time = time.perf_counter() - prompts, answers = reasoning.generate_prompts_and_answers(100) + prompts, answers = reasoning.generate_prompts_and_answers(nb) delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") - predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - predicted_answers = torch.logical_not(predicted_prompts) + # m = torch.randint(2, (prompts.size(0),)) + # predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) + # predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) reasoning.save_quizzes( "/tmp", "test", - prompts[:64], - answers[:64], + prompts[:nb], + answers[:nb], # You can add a bool to put a frame around the predicted parts - # predicted_prompts[:64], - # predicted_answers[:64], + # predicted_prompts[:nb], + # predicted_answers[:nb], ) diff --git a/sky.py b/sky.py index 6ef8a3a..ed440d3 100755 --- a/sky.py +++ b/sky.py @@ -217,9 +217,11 @@ class Sky(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([0, 0, 0], device=c.device) + ( - 1 - c - ) * torch.tensor([255, 255, 255], device=c.device) + c = ( + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + ) y[...] = c[:, :, None, None] y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x @@ -322,8 +324,8 @@ if __name__ == "__main__": prompts, answers = sky.generate_prompts_and_answers(4) - predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - predicted_answers = torch.rand(answers.size(0)) < 0.5 + predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 + predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 sky.save_quizzes( "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers -- 2.20.1