self.check_structure(quizzes, struct)
return struct
- def inject_noise(self, quizzes, noise): # , struct=None, mask=None):
- # assert self.check_structure(quizzes, struct=struct)
+ def inject_noise(self, quizzes, noise, struct, mask):
+ assert self.check_structure(quizzes, struct=struct)
S = self.height * self.width
- # mask = torch.tensor(mask, device=quizzes.device)
- # mask = mask[None, :, None].expand(1, 4, S + 1)
- # mask[:, :, 0] = 0
- # mask = mask.reshape(1, -1).expand_as(quizzes)
+ mask = torch.tensor(mask, device=quizzes.device)
+ mask = mask[None, :, None].expand(1, 4, S + 1)
+ mask[:, :, 0] = 0
+ mask = mask.reshape(1, -1).expand_as(quizzes)
- mask = quizzes.new_full(quizzes.size(), 1)
- mask[:, 0 * (S + 1)] = 0
- mask[:, 1 * (S + 1)] = 0
- mask[:, 2 * (S + 1)] = 0
- mask[:, 3 * (S + 1)] = 0
mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
-
random = torch.randint(self.nb_colors, mask.size())
+
quizzes = mask * random + (1 - mask) * quizzes
return quizzes
i = torch.randperm(quizzes.size(0), device=quizzes.device)
quizzes, from_w = quizzes[i], from_w[i]
- if self.prompt_noise > 0.0:
- quizzes = self.problem.inject_noise(
- quizzes,
- self.prompt_noise,
- # struct=("A", "f_A", "B", "f_B"),
- # mask=(1, 0, 1, 0),
- )
-
self.randomize_configuations_inplace(
quizzes, structs=[s for s, m in self.understood_structures]
)
+ if self.prompt_noise > 0.0:
+ for struct, mask in self.understood_structures:
+ i = self.problem.indices_select(quizzes=input, struct=struct)
+ input[i] = self.problem.inject_noise(
+ input[i],
+ self.prompt_noise,
+ struct=struct,
+ mask=(1 - k for k in mask),
+ )
+
return quizzes, from_w
######################################################################