+def add_hints_imt(imt_set):
+ """Set every component of the mask to zero with probability
+ args.proba_hint, and for each component set to zero, copy the
+ corresponding value from the target into the input
+
+ """
+ input, masks, targets = imt_set.unbind(dim=1)
+ # h = torch.rand(masks.size(), device=masks.device) - masks
+ # t = h.sort(dim=1).values[:, args.nb_hints, None]
+ # mask_hints = (h < t).long()
+ mask_hints = (
+ torch.rand(input.size(), device=input.device) < args.proba_hint
+ ).long() * masks
+ masks = (1 - mask_hints) * masks
+ input = (1 - mask_hints) * input + mask_hints * targets
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def add_noise_imt(imt_set):
+ """Replace every component of the input by a random value with
+ probability args.proba_prompt_noise."""
+ input, masks, targets = imt_set.unbind(dim=1)
+ noise = problem.pure_noise(input.size(0), input.device)
+ change = (1 - masks) * (
+ torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
+ ).long()
+ input = (1 - change) * input + change * noise
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+######################################################################
+# Prediction
+
+
+def samples_for_prediction_imt(input):
+ nb = input.size(0)
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
+ targets = input
+ input = (1 - masks) * targets
+
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def ae_predict(model, imt_set, local_device=main_device):
+ model.eval().to(local_device)
+
+ record = []
+
+ src = tqdm.tqdm(
+ imt_set.split(args.eval_batch_size),
+ dynamic_ncols=True,
+ desc="predict",
+ total=imt_set.size(0) // args.eval_batch_size,
+ delay=10,