self.check_structure(quizzes, struct)
return struct
- def inject_noise(self, quizzes, noise, struct, mask):
- assert self.check_structure(quizzes, struct=struct)
+ def inject_noise(self, quizzes, noise, struct=None, mask=None):
+ # assert self.check_structure(quizzes, struct=struct)
S = self.height * self.width
- mask = torch.tensor(mask, device=quizzes.device)
- mask = mask[None, :, None].expand(1, 4, S + 1)
- mask[:, :, 0] = 0
+
+ # mask = torch.tensor(mask, device=quizzes.device)
+ # mask = mask[None, :, None].expand(1, 4, S + 1)
+ # mask[:, :, 0] = 0
+ # mask = mask.reshape(1, -1).expand_as(quizzes)
+
+ mask = quizzes.new_full(quizzes.size(), 1)
+ mask[:, 0 * (S + 1)] = 0
+ mask[:, 1 * (S + 1)] = 0
+ mask[:, 2 * (S + 1)] = 0
+ mask[:, 3 * (S + 1)] = 0
mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
- mask = mask.reshape(1, -1).expand_as(quizzes)
- random = torch.randint(self.nb_colors, mask.size())
+ random = torch.randint(self.nb_colors, mask.size())
quizzes = mask * random + (1 - mask) * quizzes
return quizzes
quizzes = self.problem.inject_noise(
quizzes,
self.prompt_noise,
- struct=("A", "f_A", "B", "f_B"),
- mask=(1, 0, 1, 0),
+ # struct=("A", "f_A", "B", "f_B"),
+ # mask=(1, 0, 1, 0),
)
self.randomize_configuations_inplace(