parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
-parser.add_argument("--prompt_noise_proba", type=float, default=0.05)
+parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
-parser.add_argument("--hint_proba", type=float, default=0.01)
+parser.add_argument("--proba_hint", type=float, default=0.01)
# parser.add_argument("--nb_hints", type=int, default=25)
# t = h.sort(dim=1).values[:, args.nb_hints, None]
# mask_hints = (h < t).long()
mask_hints = (
- torch.rand(input.size(), device=input.device) < args.hint_proba
+ torch.rand(input.size(), device=input.device) < args.proba_hint
).long() * masks
masks = (1 - mask_hints) * masks
input = (1 - mask_hints) * input + mask_hints * targets
# Make pixels from the available input (mask=0) noise with probability
-# args.prompt_noise_proba
+# args.proba_prompt_noise
def add_noise(imt_set):
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
noise = quiz_machine.pure_noise(input.size(0), input.device)
change = (1 - masks) * (
- torch.rand(input.size(), device=input.device) < args.prompt_noise_proba
+ torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
).long()
input = (1 - change) * input + change * noise
return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)