From f848e5847554870b26e6219e33c845669f4663b3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 1 Aug 2024 11:51:54 +0200 Subject: [PATCH] Update. --- quiz_machine.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/quiz_machine.py b/quiz_machine.py index bfa7f97..a042431 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -82,11 +82,11 @@ class QuizMachine: self.prompt_noise = prompt_noise self.understood_structures = [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), - (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), - (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)), - (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)), - (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)), + (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0)), + (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0)), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0)), ] self.LOCK_C_QUIZZES = threading.Lock() @@ -178,18 +178,15 @@ class QuizMachine: quizzes, from_w = quizzes[i], from_w[i] self.randomize_configuations_inplace( - quizzes, structs=[s for s, m in self.understood_structures] + quizzes, structs=[s for s, m, _ in self.understood_structures] ) if self.prompt_noise > 0.0: - for struct, mask in self.understood_structures: + for struct, mask, noise_mask in self.understood_structures: i = self.problem.indices_select(quizzes=quizzes, struct=struct) if i.any(): quizzes[i] = self.problem.inject_noise( - quizzes[i], - self.prompt_noise, - struct=struct, - mask=tuple(1 - k for k in mask), + quizzes[i], self.prompt_noise, struct=struct, mask=noise_mask ) return quizzes, from_w @@ -197,7 +194,7 @@ class QuizMachine: ###################################################################### def make_ar_mask(self, quizzes, struct, mask): - assert struct in [s for s, m in self.understood_structures] + assert struct in [s for s, _, _ in self.understood_structures] return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask) ###################################################################### @@ -231,7 +228,7 @@ class QuizMachine: nb = 0 # We consider all the configurations that we train for - for struct, mask in self.understood_structures: + for struct, mask, noise_mask in self.understood_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i] = self.predict( -- 2.39.5