Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 08:09:33 +0000 (10:09 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 08:09:33 +0000 (10:09 +0200)
grids.py
quiz_machine.py

index 410b538..19b9ce4 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -167,16 +167,23 @@ class Grids(problem.Problem):
         self.check_structure(quizzes, struct)
         return struct
 
-    def inject_noise(self, quizzes, noise, struct, mask):
-        assert self.check_structure(quizzes, struct=struct)
+    def inject_noise(self, quizzes, noise, struct=None, mask=None):
+        assert self.check_structure(quizzes, struct=struct)
         S = self.height * self.width
-        mask = torch.tensor(mask, device=quizzes.device)
-        mask = mask[None, :, None].expand(1, 4, S + 1)
-        mask[:, :, 0] = 0
+
+        # mask = torch.tensor(mask, device=quizzes.device)
+        # mask = mask[None, :, None].expand(1, 4, S + 1)
+        # mask[:, :, 0] = 0
+        # mask = mask.reshape(1, -1).expand_as(quizzes)
+
+        mask = quizzes.new_full(quizzes.size(), 1)
+        mask[:, 0 * (S + 1)] = 0
+        mask[:, 1 * (S + 1)] = 0
+        mask[:, 2 * (S + 1)] = 0
+        mask[:, 3 * (S + 1)] = 0
         mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
-        mask = mask.reshape(1, -1).expand_as(quizzes)
-        random = torch.randint(self.nb_colors, mask.size())
 
+        random = torch.randint(self.nb_colors, mask.size())
         quizzes = mask * random + (1 - mask) * quizzes
 
         return quizzes
index fb451f2..90ca5e6 100755 (executable)
@@ -181,8 +181,8 @@ class QuizMachine:
             quizzes = self.problem.inject_noise(
                 quizzes,
                 self.prompt_noise,
-                struct=("A", "f_A", "B", "f_B"),
-                mask=(1, 0, 1, 0),
+                struct=("A", "f_A", "B", "f_B"),
+                mask=(1, 0, 1, 0),
             )
 
         self.randomize_configuations_inplace(