From 20461c9c5331fa427f7efb1534b6155cab4926dc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 17 Sep 2024 20:18:24 +0200 Subject: [PATCH] Update. --- grids.py | 13 +++++++++++++ main.py | 22 +++++----------------- quiz_machine.py | 9 ++++++++- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/grids.py b/grids.py index 490750b..9424496 100755 --- a/grids.py +++ b/grids.py @@ -422,6 +422,19 @@ class Grids(problem.Problem): ): quizzes = quizzes.to("cpu") + if quizzes.size(1) == 4 * self.height * self.width: + quizzes = torch.cat( + [ + quizzes.new_zeros(quizzes.size(0), 4, 1), + quizzes.reshape(quizzes.size(0), 4, -1), + ], + dim=2, + ) + quizzes[:, :, 0] = torch.tensor( + [self.token_A, self.token_f_A, self.token_B, self.token_f_B] + )[None, :] + quizzes = quizzes.reshape(quizzes.size(0), -1) + to_reconfigure = [quizzes] if predicted_parts is not None: to_reconfigure.append(predicted_parts) diff --git a/main.py b/main.py index 9525bdd..6cbb2c4 100755 --- a/main.py +++ b/main.py @@ -326,11 +326,6 @@ quiz_machine = quiz_machine.QuizMachine( device=main_device, ) - -diffuser = diffusion.Diffuser( - mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption -) - ###################################################################### log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}") @@ -412,11 +407,10 @@ def batch_prediction_imt(input, fraction_with_hints=0.0): nb = input.size(0) masks = input.new_zeros(input.size()) u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4) - masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] + masks.view(nb, 4, -1)[...] = u[:, :, None] masks = add_hints(masks, fraction_with_hints) - # noise = quiz_machine.problem.pure_noise(nb, input.device) targets = input - input = (1 - masks) * targets # + masks * noise + input = (1 - masks) * targets return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) @@ -446,12 +440,11 @@ def predict(model, imt_set, local_device=main_device): def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device): - boy_that_s_ugly = input.view(input.size(0), 4, -1)[:, :, 0].clone() input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1)) nb = input.size(0) masks = input.new_zeros(input.size()) u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4) - masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] + masks.view(nb, 4, -1)[...] = u[:, :, None] masks_with_hints = add_hints(masks, fraction_with_hints) targets = input input = (1 - masks_with_hints) * targets @@ -462,8 +455,6 @@ def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device result = predict(model, imt_set, local_device=local_device) result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1) - result.view(result.size(0), 4, -1)[:, :, 0] = boy_that_s_ugly - return result @@ -483,12 +474,10 @@ def batch_generation_imt(input): proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t mask_erased = (r <= proba_erased[:, None]).long() - noise = quiz_machine.problem.pure_noise(nb, input.device) - + noise = quiz_machine.pure_noise(nb, input.device) targets = input input = (1 - mask_erased) * input + mask_erased * noise masks = input.new_full(input.size(), 1) - masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0 return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) @@ -503,9 +492,8 @@ def prioritized_rand(low): def generate(model, nb, local_device=main_device): - all_input = quiz_machine.problem.pure_noise(nb, local_device) + all_input = quiz_machine.pure_noise(nb, local_device) all_masks = all_input.new_full(all_input.size(), 1) - all_masks.reshape(all_masks.size(0), 4, -1)[:, :, 0] = 0 for input, masks in tqdm.tqdm( zip( diff --git a/quiz_machine.py b/quiz_machine.py index dfedbf5..594b5ca 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -195,6 +195,11 @@ class QuizMachine: ###################################################################### + def pure_noise(self, nb, device): + r = self.problem.pure_noise(nb, device) + r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1) + return r + def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1): if c_quizzes is None: quizzes = self.problem.generate_w_quizzes(nb_samples) @@ -222,7 +227,9 @@ class QuizMachine: i = torch.randperm(quizzes.size(0), device=quizzes.device) quizzes = quizzes[i].contiguous() - quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].contiguous() + quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape( + quizzes.size(0), -1 + ) return quizzes -- 2.39.5