):
quizzes = quizzes.to("cpu")
+ if quizzes.size(1) == 4 * self.height * self.width:
+ quizzes = torch.cat(
+ [
+ quizzes.new_zeros(quizzes.size(0), 4, 1),
+ quizzes.reshape(quizzes.size(0), 4, -1),
+ ],
+ dim=2,
+ )
+ quizzes[:, :, 0] = torch.tensor(
+ [self.token_A, self.token_f_A, self.token_B, self.token_f_B]
+ )[None, :]
+ quizzes = quizzes.reshape(quizzes.size(0), -1)
+
to_reconfigure = [quizzes]
if predicted_parts is not None:
to_reconfigure.append(predicted_parts)
device=main_device,
)
-
-diffuser = diffusion.Diffuser(
- mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption
-)
-
######################################################################
log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
nb = input.size(0)
masks = input.new_zeros(input.size())
u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
- masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
masks = add_hints(masks, fraction_with_hints)
- # noise = quiz_machine.problem.pure_noise(nb, input.device)
targets = input
- input = (1 - masks) * targets # + masks * noise
+ input = (1 - masks) * targets
return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device):
- boy_that_s_ugly = input.view(input.size(0), 4, -1)[:, :, 0].clone()
input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
nb = input.size(0)
masks = input.new_zeros(input.size())
u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
- masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
masks_with_hints = add_hints(masks, fraction_with_hints)
targets = input
input = (1 - masks_with_hints) * targets
result = predict(model, imt_set, local_device=local_device)
result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
- result.view(result.size(0), 4, -1)[:, :, 0] = boy_that_s_ugly
-
return result
proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
mask_erased = (r <= proba_erased[:, None]).long()
- noise = quiz_machine.problem.pure_noise(nb, input.device)
-
+ noise = quiz_machine.pure_noise(nb, input.device)
targets = input
input = (1 - mask_erased) * input + mask_erased * noise
masks = input.new_full(input.size(), 1)
- masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
def generate(model, nb, local_device=main_device):
- all_input = quiz_machine.problem.pure_noise(nb, local_device)
+ all_input = quiz_machine.pure_noise(nb, local_device)
all_masks = all_input.new_full(all_input.size(), 1)
- all_masks.reshape(all_masks.size(0), 4, -1)[:, :, 0] = 0
for input, masks in tqdm.tqdm(
zip(
######################################################################
+ def pure_noise(self, nb, device):
+ r = self.problem.pure_noise(nb, device)
+ r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1)
+ return r
+
def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1):
if c_quizzes is None:
quizzes = self.problem.generate_w_quizzes(nb_samples)
i = torch.randperm(quizzes.size(0), device=quizzes.device)
quizzes = quizzes[i].contiguous()
- quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].contiguous()
+ quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
+ quizzes.size(0), -1
+ )
return quizzes