From 7df735c26454770ab29db78db5a3d03f41245a7c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 30 Jul 2024 12:20:13 +0200 Subject: [PATCH] Update. --- grids.py | 8 +++----- main.py | 4 +++- quiz_machine.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/grids.py b/grids.py index ebc2b0e..8d274ad 100755 --- a/grids.py +++ b/grids.py @@ -342,17 +342,15 @@ class Grids(problem.Problem): ): quizzes = quizzes.to("cpu") - to_reconfigure = [result] + to_reconfigure = [quizzes] if predicted_parts is not None: to_reconfigure.append(predicted_parts) if correct_parts is not None: to_reconfigure.append(correct_parts) - to_reconfigure = self.problem.reconfigure( - to_reconfigure, ("A", "f_A", "B", "f_B") - ) + to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B")) - result = to_reconfigure.pop(0) + quizzes = to_reconfigure.pop(0) if predicted_parts is not None: predicted_parts = to_reconfigure.pop(0) if correct_parts is not None: diff --git a/main.py b/main.py index 6f543a0..455aa1c 100755 --- a/main.py +++ b/main.py @@ -493,7 +493,9 @@ def save_additional_results(models, science_w_quizzes): mask=mask, ) - predicted_parts = torch.tensor(mask, device=correct.device)[None, :] + predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand( + correct.size(0), -1 + ) correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long() nb_correct = (correct == 1).long().sum() diff --git a/quiz_machine.py b/quiz_machine.py index 1ff23ed..90879ce 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -172,7 +172,7 @@ class QuizMachine: from_w = torch.full((quizzes.size(0),), True, device=quizzes.device) self.randomize_configuations_inplace( - quizzes, structs=[s for s in self.understood_structures] + quizzes, structs=[s for s, m in self.understood_structures] ) i = torch.randperm(quizzes.size(0), device=quizzes.device) @@ -182,7 +182,7 @@ class QuizMachine: ###################################################################### def make_ar_mask(self, quizzes, struct, mask): - assert struct in [s for s in self.understood_structures] + assert struct in [s for s, m in self.understood_structures] return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) ###################################################################### @@ -296,7 +296,7 @@ class QuizMachine: ) self.randomize_configuations_inplace( - model.train_w_quizzes, structs=[s for s in self.understood_structures] + model.train_w_quizzes, structs=[s for s, m in self.understood_structures] ) ###################################################################### -- 2.39.5