+def predict_the_four_grids(
+    model, input, with_noise=False, with_hints=False, local_device=main_device
+):
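+    # Replicate each sequence four times and build, for copy i, a
+    # mask selecting grid i % 4, so that every grid is erased and
+    # predicted from the three others in a single batched pass.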
+    input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
+    nb = input.size(0)
+    masks = input.new_zeros(input.size())
+    u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
+    masks.view(nb, 4, -1)[...] = u[:, :, None]
+    targets = input
+    input = (1 - masks) * targets
+    imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+    if with_hints:
+        imt_set = add_hints_imt(imt_set)
+
+    if with_noise:
+        imt_set = add_noise_imt(imt_set)
+
+    result = ae_predict(model, imt_set, local_device=local_device)
+    result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
+
+    return result
+
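+# Hypothetical usage sketch, not part of this change: assuming `model`
+# is a trained masked autoencoder and `quizzes` is an (N, 4*G) tensor
+# of token sequences, each the concatenation of four G-token grids,
+#
+#   predictions = predict_the_four_grids(model, quizzes, with_hints=True)
+#
+# returns an (N, 4*G) tensor in which every grid has been re-predicted
+# from the other three.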
+
+######################################################################
+
+
+def samples_for_generation_imt(input):
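+    # Sample a per-sequence number of diffusion iterations t, with a
+    # prior that decays geometrically in t, then erase each token
+    # independently with probability 1 - (1 - p)**t, the probability
+    # of being corrupted at least once in t steps with per-step
+    # corruption probability p.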
+    nb = input.size(0)
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.diffusion_nb_iterations, device=input.device
+    )
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(nb, -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+    t = dist.sample() + 1
+    r = torch.rand(input.size(), device=input.device)
+    proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
+    mask_erased = (r <= proba_erased[:, None]).long()
+
+    noise = problem.pure_noise(nb, input.device)
+    targets = input
+    input = (1 - mask_erased) * input + mask_erased * noise
+    masks = input.new_full(input.size(), 1)
+
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
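+# Worked example with illustrative values: if
+# args.diffusion_proba_corruption = 0.05 and the sampled t = 10, then
+# proba_erased = 1 - 0.95**10 ≈ 0.40, so roughly 40% of that sample's
+# tokens are replaced with pure noise. The 0.1**linspace prior makes
+# small t, hence lightly corrupted samples, the most likely.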
+
+def prioritized_rand(low):
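+    # Return uniform random values such that, in every row, the
+    # entries at the positions where `low` is non-zero are all
+    # smaller than the entries at the other positions: the draws are
+    # sorted in descending order, then scattered through indices that
+    # place the `low` positions last.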
+    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+    k = torch.rand(low.size(), device=low.device) + low.long()
+    k = k.sort(dim=1).indices
+    y = torch.empty_like(x)
+    y.scatter_(dim=1, index=k, src=x)
+    return y
+
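+# Minimal sanity check with hypothetical values: for
+# low = torch.tensor([[0, 1, 0]]), the smallest value of the returned
+# row always lands at index 1,
+#
+#   y = prioritized_rand(torch.tensor([[0, 1, 0]]))
+#   assert y[0, 1] <= y[0, 0] and y[0, 1] <= y[0, 2]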
+
+def ae_generate(model, nb, local_device=main_device):
+    model.eval().to(local_device)
+
+    # We iterate over the diffusion steps in the outer loop and over
+    # the mini-batches in the inner loop, so that at every step we
+    # keep processing only the samples that have not yet stabilized
+
+    all_input = problem.pure_noise(nb, local_device)
+    all_masks = all_input.new_full(all_input.size(), 1)
+    all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
+
+    for it in range(args.diffusion_nb_iterations):
+        # log_string(f"nb_changed {all_changed.long().sum().item()}")
+
+        if not all_changed.any():
+            break
+
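+        # Restrict the step to the samples that changed at the
+        # previous iteration; all_changed[all_changed] yields an
+        # all-True tensor of the right size.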
+        sub_input = all_input[all_changed].clone()
+        sub_masks = all_masks[all_changed].clone()
+        sub_changed = all_changed[all_changed].clone()
+
+        src = zip(
+            sub_input.split(args.eval_batch_size),
+            sub_masks.split(args.eval_batch_size),
+            sub_changed.split(args.eval_batch_size),