From 0bc23d875cd730ea1488689b1964db4da2f7442a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 3 Sep 2024 08:51:14 +0200 Subject: [PATCH] Update. --- main.py | 45 ++++++++++++++++++++++++++++++--------------- quiz_machine.py | 14 -------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/main.py b/main.py index 4860073..9b2282f 100755 --- a/main.py +++ b/main.py @@ -113,7 +113,7 @@ parser.add_argument("--temperature_hot", type=float, default=1.5) parser.add_argument("--temperature_cold", type=float, default=1) -parser.add_argument("--prompt_noise", type=float, default=0.05) +parser.add_argument("--prompt_noise", type=float, default=0.0) parser.add_argument("--dirty_debug", action="store_true", default=False) @@ -298,7 +298,6 @@ quiz_machine = quiz_machine.QuizMachine( problem=problem, batch_size=args.inference_batch_size, result_dir=args.result_dir, - prompt_noise=args.prompt_noise, logger=log_string, device=main_device, ) @@ -1098,7 +1097,7 @@ def model_ae_proba_solutions(model, input, log_proba=False): nb_diffusion_iterations = 25 -def degrade_input(input, mask_generate, nb_iterations): +def degrade_input_to_generate(input, mask_generate, nb_iterations): noise = torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) @@ -1116,20 +1115,36 @@ def degrade_input(input, mask_generate, nb_iterations): return result -def targets_and_prediction(model, input, mask_generate): +def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0): d = deterministic(mask_generate) + probs_iterations = 0.1 ** torch.linspace( 0, 1, args.nb_diffusion_iterations, device=input.device ) + probs_iterations = probs_iterations[None, :] / probs_iterations.sum() probs_iterations = probs_iterations.expand(input.size(0), -1) dist = torch.distributions.categorical.Categorical(probs=probs_iterations) - N0 = dist.sample() - N1 = N0 + 1 - N0 = (1 - d) * N0 - N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations - targets, input = degrade_input(input, mask_generate, (0 * N1, N1)) + # N0 = dist.sample() + # N1 = N0 + 1 + # N0 = (1 - d) * N0 + # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations + + N0 = input.new_zeros(input.size(0)) + N1 = dist.sample() + 1 + + targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1)) + + if prompt_noise > 0: + mask_prompt_noise = ( + torch.rand(input.size(), device=input.device()) <= prompt_noise + ).long() + noise = torch.randint( + quiz_machine.problem.nb_colors, input.size(), device=input.device + ) + noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise + input = (1 - mask_generate) * noisy_input + mask_generate * input input_with_mask = NTC_channel_cat(input, mask_generate) logits = model(input_with_mask) @@ -1250,9 +1265,7 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_ ###################################################################### -def one_ae_epoch( - model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device -): +def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device): model.train().to(local_device) optimizer_to(model.optimizer, local_device) @@ -1273,7 +1286,9 @@ def one_ae_epoch( if nb_train_samples % args.batch_size == 0: model.optimizer.zero_grad() - targets, logits = targets_and_prediction(model, input, mask_generate) + targets, logits = targets_and_prediction( + model, input, mask_generate, prompt_noise=args.prompt_noise + ) loss = NTC_masked_cross_entropy(logits, targets, mask_loss) acc_train_loss += loss.item() * input.size(0) @@ -1572,7 +1587,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # -------------------------------------------------------------------- - # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device) + # one_ae_epoch(models[0], quiz_machine, n_epoch, main_device) # exit(0) log_string(f"{time_train=} {time_c_quizzes=}") @@ -1612,7 +1627,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): t = threading.Thread( target=one_ae_epoch, daemon=True, - args=(model, models, quiz_machine, n_epoch, c_quizzes, gpu), + args=(model, quiz_machine, n_epoch, c_quizzes, gpu), ) threads.append(t) diff --git a/quiz_machine.py b/quiz_machine.py index ce4d4f5..f1eb9db 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -67,7 +67,6 @@ class QuizMachine: problem, batch_size, result_dir, - prompt_noise, logger, device=torch.device("cpu"), ): @@ -79,7 +78,6 @@ class QuizMachine: self.logger = logger self.prompt_len = None self.answer_len = None - self.prompt_noise = prompt_noise # quad_order, quad_generate, quad_noise, quad_loss self.train_structures = [ @@ -186,13 +184,6 @@ class QuizMachine: quad_order, quad_generate, quad_noise, quad_loss = s i = order_ids == j quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order) - if self.prompt_noise > 0.0: - quizzes[i] = self.problem.inject_noise( - quizzes[i], - self.prompt_noise, - quad_order=quad_order, - quad_noise=quad_noise, - ) quiz_mask_generate[i] = self.make_quiz_mask( quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate ) @@ -335,11 +326,6 @@ class QuizMachine: device=device, ) - # if self.prompt_noise > 0.0 and quad_noise is not None: - # c_quizzes = self.problem.inject_noise( - # c_quizzes, self.prompt_noise, quad_order=quad_order, quad_noise=quad_noise - # ) - with torch.autograd.no_grad(): t = model.training model.eval() -- 2.39.5