Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 08:18:14 +0000 (10:18 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 08:18:14 +0000 (10:18 +0200)
grids.py
quiz_machine.py

index 9b1ed50..f747c47 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -167,23 +167,18 @@ class Grids(problem.Problem):
         self.check_structure(quizzes, struct)
         return struct
 
-    def inject_noise(self, quizzes, noise):  # , struct=None, mask=None):
-        assert self.check_structure(quizzes, struct=struct)
+    def inject_noise(self, quizzes, noise, struct, mask):
+        assert self.check_structure(quizzes, struct=struct)
         S = self.height * self.width
 
-        mask = torch.tensor(mask, device=quizzes.device)
-        mask = mask[None, :, None].expand(1, 4, S + 1)
-        mask[:, :, 0] = 0
-        mask = mask.reshape(1, -1).expand_as(quizzes)
+        mask = torch.tensor(mask, device=quizzes.device)
+        mask = mask[None, :, None].expand(1, 4, S + 1)
+        mask[:, :, 0] = 0
+        mask = mask.reshape(1, -1).expand_as(quizzes)
 
-        mask = quizzes.new_full(quizzes.size(), 1)
-        mask[:, 0 * (S + 1)] = 0
-        mask[:, 1 * (S + 1)] = 0
-        mask[:, 2 * (S + 1)] = 0
-        mask[:, 3 * (S + 1)] = 0
         mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
-
         random = torch.randint(self.nb_colors, mask.size())
+
         quizzes = mask * random + (1 - mask) * quizzes
 
         return quizzes
index 90ca5e6..a4ca60b 100755 (executable)
@@ -177,18 +177,20 @@ class QuizMachine:
         i = torch.randperm(quizzes.size(0), device=quizzes.device)
         quizzes, from_w = quizzes[i], from_w[i]
 
-        if self.prompt_noise > 0.0:
-            quizzes = self.problem.inject_noise(
-                quizzes,
-                self.prompt_noise,
-                # struct=("A", "f_A", "B", "f_B"),
-                # mask=(1, 0, 1, 0),
-            )
-
         self.randomize_configuations_inplace(
             quizzes, structs=[s for s, m in self.understood_structures]
         )
 
+        if self.prompt_noise > 0.0:
+            for struct, mask in self.understood_structures:
+                i = self.problem.indices_select(quizzes=input, struct=struct)
+                input[i] = self.problem.inject_noise(
+                    input[i],
+                    self.prompt_noise,
+                    struct=struct,
+                    mask=(1 - k for k in mask),
+                )
+
         return quizzes, from_w
 
     ######################################################################