Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 09:51:54 +0000 (11:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 09:51:54 +0000 (11:51 +0200)
quiz_machine.py

index bfa7f97..a042431 100755 (executable)
@@ -82,11 +82,11 @@ class QuizMachine:
         self.prompt_noise = prompt_noise
 
         self.understood_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)),
-            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)),
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0)),
+            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0)),
         ]
 
         self.LOCK_C_QUIZZES = threading.Lock()
@@ -178,18 +178,15 @@ class QuizMachine:
         quizzes, from_w = quizzes[i], from_w[i]
 
         self.randomize_configuations_inplace(
-            quizzes, structs=[s for s, m in self.understood_structures]
+            quizzes, structs=[s for s, m, _ in self.understood_structures]
         )
 
         if self.prompt_noise > 0.0:
-            for struct, mask in self.understood_structures:
+            for struct, mask, noise_mask in self.understood_structures:
                 i = self.problem.indices_select(quizzes=quizzes, struct=struct)
                 if i.any():
                     quizzes[i] = self.problem.inject_noise(
-                        quizzes[i],
-                        self.prompt_noise,
-                        struct=struct,
-                        mask=tuple(1 - k for k in mask),
+                        quizzes[i], self.prompt_noise, struct=struct, mask=noise_mask
                     )
 
         return quizzes, from_w
@@ -197,7 +194,7 @@ class QuizMachine:
     ######################################################################
 
     def make_ar_mask(self, quizzes, struct, mask):
-        assert struct in [s for s, m in self.understood_structures]
+        assert struct in [s for s, _, _ in self.understood_structures]
         return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
 
     ######################################################################
@@ -231,7 +228,7 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, mask in self.understood_structures:
+        for struct, mask, noise_mask in self.understood_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i] = self.predict(