From a5e817de25140f2a16e02ac4d2b27f45af9679b0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 15 Sep 2024 22:02:09 +0200 Subject: [PATCH] Update. --- diffusion.py | 43 ++++++++++++++++--------------------------- main.py | 51 ++++++++++++++++++++++++++++++--------------------- 2 files changed, 46 insertions(+), 48 deletions(-) diff --git a/diffusion.py b/diffusion.py index 2dc5861..abe8986 100755 --- a/diffusion.py +++ b/diffusion.py @@ -52,17 +52,18 @@ class Diffuser: ###################################################################### - def make_mask_hints(self, mask_generate, nb_hints): - if nb_hints == 0: + def make_mask_hints(mask_generate, nb_hints): + if nb_hints is None: mask_hints = None else: u = ( torch.rand(mask_generate.size(), device=mask_generate.device) * mask_generate ) - mask_hints = ( - u > u.sort(dim=1, descending=True).values[:, nb_hints, None] - ).long() + v = u.sort(dim=1, descending=True).values.gather( + dim=1, index=nb_hints[:, None] + ) + mask_hints = (u > v).long() return mask_hints @@ -71,7 +72,7 @@ class Diffuser: # logits starting from a x_t|X_0=x_0 picked at random with t random def logits_hat_x_0_from_random_iteration( - self, model, x_0, mask_generate, nb_hints=0, prompt_noise=0.0 + self, model, x_0, mask_generate, nb_hints=None, prompt_noise=0.0 ): noise = self.mu_T_sampler(x_0.size(), device=x_0.device) @@ -79,12 +80,7 @@ class Diffuser: mask_generate.sum(dim=1) < mask_generate.size(1) // 2 ).long()[:, None] - mask_hints = self.make_mask_hints(mask_generate, nb_hints) - - if mask_hints is None: - mask_start = mask_generate - else: - mask_start = mask_generate * (1 - mask_hints) + mask_hints = self.make_mask_hints(mask_generate, nb_hints) * single_iteration # We favor iterations near the clean signal @@ -98,13 +94,9 @@ class Diffuser: t = dist.sample() + 1 - x_t = single_iteration * noise + ( - 1 - single_iteration - ) * self.sample_x_t_given_x_0(x_0, t) - - # Only the part to generate is degraded, the rest is a perfect - # noise-free conditionning - + x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise + x_t = self.sample_x_t_given_x_0(x_0, t) + x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t x_t = (1 - mask_generate) * x_0 + mask_generate * x_t # We may inject noise to prevent high-complexity non-structure @@ -128,7 +120,7 @@ class Diffuser: ###################################################################### - def generate(self, model, x_0, mask_generate, nb_hints=0): + def generate(self, model, x_0, mask_generate, nb_hints=None): noise = self.mu_T_sampler(x_0.size(), device=x_0.device) single_iteration = ( @@ -137,12 +129,10 @@ class Diffuser: mask_hints = self.make_mask_hints(mask_generate, nb_hints) - if mask_hints is None: - mask_start = mask_generate - else: - mask_start = mask_generate * (1 - mask_hints) - - x_t = (1 - mask_start) * x_0 + mask_start * noise + x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise + x_t = self.sample_x_t_given_x_0(x_0, t) + x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t + x_t = (1 - mask_generate) * x_0 + mask_generate * x_t changed = True @@ -150,7 +140,6 @@ class Diffuser: x_t_with_mask = NTC_channel_cat(x_t, mask_generate) with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): logits = model(x_t_with_mask) - # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf") dist = torch.distributions.categorical.Categorical(logits=logits) hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample() diff --git a/main.py b/main.py index d508c97..0d46aa2 100755 --- a/main.py +++ b/main.py @@ -405,7 +405,10 @@ def model_proba_solutions(model, input, log_probas=False, reduce=True): quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) logits = logits_hat_x_0_from_random_iteration( - model, x_0, mask_generate, prompt_noise=args.prompt_noise + model=model, + x_0=x_0, + mask_generate=mask_generate, + prompt_noise=args.prompt_noise, ) loss_per_token = F.cross_entropy( logits.transpose(1, 2), x_0, reduction="none" @@ -543,21 +546,20 @@ def run_test( # Save some images - if n_epoch < 100: - for f, record in [("prediction", record_d), ("generation", record_nd)]: - result, predicted_parts, correct_parts = bag_to_tensors(record) + for f, record in [("prediction", record_d), ("generation", record_nd)]: + result, predicted_parts, correct_parts = bag_to_tensors(record) - filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png" + filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=result[:128], - predicted_parts=predicted_parts[:128], - correct_parts=correct_parts[:128], - ) + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result[:128], + predicted_parts=predicted_parts[:128], + correct_parts=correct_parts[:128], + ) - log_string(f"wrote {filename}") + log_string(f"wrote {filename}") return nb_correct / nb_total @@ -587,12 +589,15 @@ def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device) if nb_train_samples % args.batch_size == 0: model.optimizer.zero_grad() + nb_hints = torch.randint(2, (x_0.size(0),), device=x_0.device) * args.nb_hints + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): logits = diffuser.logits_hat_x_0_from_random_iteration( model=model, x_0=x_0, mask_generate=mask_generate, prompt_noise=args.prompt_noise, + nb_hints=nb_hints, ) loss = NTC_masked_cross_entropy(logits, x_0, mask_generate) @@ -669,7 +674,7 @@ def quiz_validation( nb_have_to_be_correct, nb_have_to_be_wrong, nb_mistakes_to_be_wrong, - nb_hints=0, + nb_hints, nb_runs=1, ): ###################################################################### @@ -677,7 +682,10 @@ def quiz_validation( if c_quizzes.size(0) > args.inference_batch_size: record = [] - for q in c_quizzes.split(args.inference_batch_size): + for q, nh in zip( + c_quizzes.split(args.inference_batch_size), + nb_hints.split(args.inference_batch_size), + ): record.append( quiz_validation( models=models, @@ -686,7 +694,7 @@ def quiz_validation( nb_have_to_be_correct=nb_have_to_be_correct, nb_have_to_be_wrong=nb_have_to_be_wrong, nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong, - nb_hints=nb_hints, + nb_hints=nh, nb_runs=nb_runs, ) ) @@ -732,9 +740,6 @@ def quiz_validation( nb_correct += correct.long() nb_wrong += wrong.long() - # log_string(f"{nb_hints=} {nb_correct=}") - # log_string(f"{nb_hints=} {nb_wrong=}") - to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong) wrong = torch.cat(record_wrong, dim=1) @@ -780,6 +785,10 @@ def generate_c_quizzes(models, nb, local_device=main_device): to_keep = quiz_machine.problem.trivial(c_quizzes) == False c_quizzes = c_quizzes[to_keep] + nb_hints = torch.full( + (c_quizzes.size(0),), args.nb_hints, device=c_quizzes.device + ) + if c_quizzes.size(0) > 0: to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation( models, @@ -788,7 +797,7 @@ def generate_c_quizzes(models, nb, local_device=main_device): nb_have_to_be_correct=args.nb_have_to_be_correct, nb_have_to_be_wrong=args.nb_have_to_be_wrong, nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong, - nb_hints=args.nb_hints, + nb_hints=nb_hints, nb_runs=args.nb_runs, ) @@ -848,7 +857,7 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False) nb_have_to_be_correct=args.nb_have_to_be_correct, nb_have_to_be_wrong=0, nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong, - nb_hints=0, + nb_hints=None, ) if solvable_only: -- 2.39.5