From 83c4c3c1b1843f418b80054de50b0cb330ccbc17 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 30 Jul 2024 11:10:03 +0200 Subject: [PATCH] Update. --- grids.py | 17 +++++++++++++++-- quiz_machine.py | 33 ++++++++++++--------------------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/grids.py b/grids.py index 296c23a..ebc2b0e 100755 --- a/grids.py +++ b/grids.py @@ -342,8 +342,21 @@ class Grids(problem.Problem): ): quizzes = quizzes.to("cpu") - if not self.check_structure(quizzes, ("A", "f_A", "B", "f_B")): - print(f"**WARNING** {filename} is not in A/f_A/B/f_B order") + to_reconfigure = [result] + if predicted_parts is not None: + to_reconfigure.append(predicted_parts) + if correct_parts is not None: + to_reconfigure.append(correct_parts) + + to_reconfigure = self.problem.reconfigure( + to_reconfigure, ("A", "f_A", "B", "f_B") + ) + + result = to_reconfigure.pop(0) + if predicted_parts is not None: + predicted_parts = to_reconfigure.pop(0) + if correct_parts is not None: + correct_parts = to_reconfigure.pop(0) S = self.height * self.width diff --git a/quiz_machine.py b/quiz_machine.py index 9ca84b3..1ff23ed 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -79,13 +79,12 @@ class QuizMachine: self.prompt_len = None self.answer_len = None - self.train_struct = [ - ("A", "f_A", "B", "f_B"), # The standard order - ("f_A", "A", "f_B", "B"), # The reverse order for validation - ("B", "f_B", "A", "f_A"), - ("f_B", "B", "f_A", "A"), - ("f_B", "f_A", "A", "B"), # The synthesis order - ("f_B", "f_A", "A", "B"), # twice! + self.understood_structures = [ + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), + (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)), + (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)), ] self.LOCK_C_QUIZZES = threading.Lock() @@ -172,7 +171,9 @@ class QuizMachine: quizzes = w_quizzes.clone() from_w = torch.full((quizzes.size(0),), True, device=quizzes.device) - self.randomize_configuations_inplace(quizzes, structs=self.train_struct) + self.randomize_configuations_inplace( + quizzes, structs=[s for s in self.understood_structures] + ) i = torch.randperm(quizzes.size(0), device=quizzes.device) @@ -181,7 +182,7 @@ class QuizMachine: ###################################################################### def make_ar_mask(self, quizzes, struct, mask): - assert struct in self.train_struct + assert struct in [s for s in self.understood_structures] return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) ###################################################################### @@ -215,13 +216,7 @@ class QuizMachine: nb = 0 # We consider all the configurations that we train for - for struct, mask in [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), - (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), - (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)), - (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)), - (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)), - ]: + for struct, mask in self.understood_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i] = self.predict( @@ -249,10 +244,6 @@ class QuizMachine: predicted_parts = predicted_parts[:128] correct_parts = correct_parts[:128] - result, predicted_parts, correct_parts = self.problem.reconfigure( - [result, predicted_parts, correct_parts], ("A", "f_A", "B", "f_B") - ) - self.problem.save_quizzes_as_image( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", @@ -305,7 +296,7 @@ class QuizMachine: ) self.randomize_configuations_inplace( - model.train_w_quizzes, structs=self.train_struct + model.train_w_quizzes, structs=[s for s in self.understood_structures] ) ###################################################################### -- 2.39.5