Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 07:12:49 +0000 (09:12 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 07:12:49 +0000 (09:12 +0200)
grids.py
quiz_machine.py

index f12fcb9..410b538 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -172,6 +172,7 @@ class Grids(problem.Problem):
         S = self.height * self.width
         mask = torch.tensor(mask, device=quizzes.device)
         mask = mask[None, :, None].expand(1, 4, S + 1)
+        mask[:, :, 0] = 0
         mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
         mask = mask.reshape(1, -1).expand_as(quizzes)
         random = torch.randint(self.nb_colors, mask.size())
index 1e973c5..fb451f2 100755 (executable)
@@ -179,7 +179,10 @@ class QuizMachine:
 
         if self.prompt_noise > 0.0:
             quizzes = self.problem.inject_noise(
-                quizzes, self.prompt_noise, ("A", "f_A", "B", "f_B"), (1, 0, 1, 0)
+                quizzes,
+                self.prompt_noise,
+                struct=("A", "f_A", "B", "f_B"),
+                mask=(1, 0, 1, 0),
             )
 
         self.randomize_configuations_inplace(
@@ -304,10 +307,6 @@ class QuizMachine:
                 model.train_w_quizzes.size(0)
             )
 
-        self.randomize_configuations_inplace(
-            model.train_w_quizzes, structs=[s for s, m in self.understood_structures]
-        )
-
     ######################################################################
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):