S = self.height * self.width
mask = torch.tensor(mask, device=quizzes.device)
mask = mask[None, :, None].expand(1, 4, S + 1)
+ mask[:, :, 0] = 0
mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
mask = mask.reshape(1, -1).expand_as(quizzes)
random = torch.randint(self.nb_colors, mask.size())
if self.prompt_noise > 0.0:
quizzes = self.problem.inject_noise(
- quizzes, self.prompt_noise, ("A", "f_A", "B", "f_B"), (1, 0, 1, 0)
+ quizzes,
+ self.prompt_noise,
+ struct=("A", "f_A", "B", "f_B"),
+ mask=(1, 0, 1, 0),
)
self.randomize_configuations_inplace(
model.train_w_quizzes.size(0)
)
- self.randomize_configuations_inplace(
- model.train_w_quizzes, structs=[s for s, m in self.understood_structures]
- )
-
######################################################################
def store_c_quizzes(self, new_c_quizzes, for_train=True):