From e362df7536b958e8a1ea5645d0beb269680c94c4 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Fri, 6 Sep 2024 12:16:21 +0200
Subject: [PATCH] Update.

---
 main.py | 188 ++++++++++++++++++++++----------------------------------
 1 file changed, 74 insertions(+), 114 deletions(-)

diff --git a/main.py b/main.py
index d1a1c8f..1e398e8 100755
--- a/main.py
+++ b/main.py
@@ -732,7 +732,7 @@ def ae_batches(
 ):
     c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
 
-    full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+    full_input, full_mask_generate, _ = quiz_machine.data_input(
         nb,
         c_quiz_bags,
         data_structures=data_structures,
@@ -742,7 +742,6 @@ def ae_batches(
     src = zip(
         full_input.split(batch_size),
         full_mask_generate.split(batch_size),
-        full_mask_loss.split(batch_size),
     )
 
     if desc is not None:
@@ -753,11 +752,10 @@ def ae_batches(
             total=full_input.size(0) // batch_size,
         )
 
-    for input, mask_generate, mask_loss in src:
+    for input, mask_generate in src:
         yield (
             input.to(local_device),
             mask_generate.to(local_device),
-            mask_loss.to(local_device),
         )
 
 
@@ -777,23 +775,23 @@ def deterministic(mask_generate):
 ######################################################################
 
 #
-# Given x_0 and t_0, t_1, ..., returns x_{t_0}, x_{t_1}, with
+# Given x_0 and t_0, t_1, ..., returns
 #
-# x_{t_k} ~ P(X_{t_k} | X_0=x_0)
+# x_{t_0}, ..., x_{t_K} ~ P(X_{t_0}, ..., X_{t_K} | X_0=x_0)
 #
 
 
-def degrade_input_to_generate(x0, mask_generate, steps_nb_iterations):
-    noise = torch.randint(quiz_machine.problem.nb_colors, x0.size(), device=x0.device)
+def degrade_input_to_generate(x_0, steps_nb_iterations):
+    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    r = torch.rand(mask_generate.size(), device=mask_generate.device)
+    r = torch.rand(x_0.size(), device=x_0.device)
 
     result = []
 
     for n in steps_nb_iterations:
         proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
-        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
-        x = (1 - mask_erased) * x0 + mask_erased * noise
+        mask_erased = (r <= proba_erased[:, None]).long()
+        x = (1 - mask_erased) * x_0 + mask_erased * noise
         result.append(x)
 
     return result
@@ -801,44 +799,45 @@ def degrade_input_to_generate(x0, mask_generate, steps_nb_iterations):
 
 ######################################################################
 
-# Given x_t and a mas
-
-def targets_and_logits(model, input, mask_generate, prompt_noise=0.0):
-    d = deterministic(mask_generate)
+def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0):
+    # We favor iterations near the clean signal
 
     probs_iterations = 0.1 ** torch.linspace(
-        0, 1, args.nb_diffusion_iterations, device=input.device
+        0, 1, args.nb_diffusion_iterations, device=x_0.device
     )
 
     probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-    probs_iterations = probs_iterations.expand(input.size(0), -1)
+    probs_iterations = probs_iterations.expand(x_0.size(0), -1)
     dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
 
-    # N0 = dist.sample()
-    # N1 = N0 + 1
-    # N0 = (1 - d) * N0
-    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
-
-    N0 = input.new_zeros(input.size(0))
     N1 = dist.sample() + 1
 
-    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
+    (x_t,) = degrade_input_to_generate(x_0, (N1,))
+
+    # Only the part to generate is degraded, the rest is a perfect
+    # noise-free conditioning
+
+    x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
+
+    # We may inject noise to prevent a high-complexity, unstructured
+    # signal from being generated as a way of "increasing reasoning
+    # complexity"
 
     if prompt_noise > 0:
         mask_prompt_noise = (
-            torch.rand(input.size(), device=input.device) <= prompt_noise
+            torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
         ).long()
         noise = torch.randint(
-            quiz_machine.problem.nb_colors, input.size(), device=input.device
+            quiz_machine.problem.nb_colors, x_t.size(), device=x_t.device
         )
-        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
-        input = (1 - mask_generate) * noisy_input + mask_generate * input
+        noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
+        x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
 
-    input_with_mask = NTC_channel_cat(input, mask_generate)
-    logits = model(input_with_mask)
+    x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+    logits_hat_x_0 = model(x_t_with_mask)
 
-    return targets, logits
+    return logits_hat_x_0
 
 
 ######################################################################
@@ -858,42 +857,38 @@ def prioritized_rand(low):
     return y
 
 
-def ae_generate(model, input, mask_generate, nb_iterations_max=50):
-    noise = torch.randint(
-        quiz_machine.problem.nb_colors, input.size(), device=input.device
-    )
+def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
+    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    input = (1 - mask_generate) * input + mask_generate * noise
+    x_t = (1 - mask_generate) * x_0 + mask_generate * noise
 
-    d = deterministic(mask_generate)[:, None]
+    one_iteration_prediction = deterministic(mask_generate)[:, None]
 
     changed = True
 
     for it in range(nb_iterations_max):
-        input_with_mask = NTC_channel_cat(input, mask_generate)
-        logits = model(input_with_mask)
+        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+        logits = model(x_t_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
-        final = dist.sample()
 
-        r = prioritized_rand(final != input)
+        hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
 
-        mask_erased = mask_generate * (r <= args.diffusion_noise_proba).long()
+        r = prioritized_rand(hat_x_0 != x_t)
 
-        mask_to_change = d * mask_generate + (1 - d) * mask_erased
+        mask_changes = (r <= args.diffusion_noise_proba).long()
 
-        update = (1 - mask_to_change) * input + mask_to_change * final
+        hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
+            1 - one_iteration_prediction
+        ) * ((1 - mask_changes) * x_t + mask_changes * hat_x_0)
 
-        if update.equal(input):
+        if hat_x_t_minus_1.equal(x_t):
             # log_string(f"exit after {it+1} iterations")
             break
         else:
-            changed = changed & (update != input).max(dim=1).values
-            input[changed] = update[changed]
-
-    # if it == nb_iterations_max:
-    # log_string(f"remains {changed.long().sum()}")
+            changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
+            x_t[changed] = hat_x_t_minus_1[changed]
 
-    return input
+    return x_t
 
 
 ######################################################################
@@ -902,18 +897,18 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
 def model_ae_proba_solutions(model, input, log_proba=False):
     record = []
 
-    for q in input.split(args.batch_size):
+    for x_0 in input.split(args.batch_size):
         loss = 0
 
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_logits(
-                model, q, mask_generate, prompt_noise=args.prompt_noise
+            logits = logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
             )
 
             loss_per_token = F.cross_entropy(
-                logits.transpose(1, 2), targets, reduction="none"
+                logits.transpose(1, 2), x_0, reduction="none"
             )
             loss += (loss_per_token * mask_generate).sum(dim=1)
         record.append(loss)
@@ -929,20 +924,20 @@ def model_ae_argmax_nb_disagreements(model, input):
     record = []
 
-    for q in input.split(args.batch_size):
+    for x_0 in input.split(args.batch_size):
         nb_disagreements = 0
 
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_logits(
-                model, q, mask_generate, prompt_noise=args.prompt_noise
+            logits = logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
            )
 
             predicted = logits.argmax(dim=-1)
 
             nb_disagreements = nb_disagreements + (
-                mask_generate * predicted != mask_generate * targets
+                mask_generate * predicted != mask_generate * x_0
             ).long().sum(dim=1)
 
         record.append(nb_disagreements)
@@ -957,18 +952,18 @@ def model_ae_argmax_predictions(model, input):
     result = input.clone()
     # result[...] = 0
 
-    for r, q in zip(result.split(args.batch_size), input.split(args.batch_size)):
+    for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)):
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
            )
-            targets, logits = targets_and_logits(
-                model, q, mask_generate, prompt_noise=args.prompt_noise
+            logits = logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
            )
 
-            predicted = logits.argmax(dim=-1)
+            hat_x_0 = logits.argmax(dim=-1)
 
-            r[...] = (1 - mask_generate) * r + mask_generate * predicted
+            r[...] = (1 - mask_generate) * r + mask_generate * hat_x_0
 
     return result
@@ -991,7 +986,7 @@ def run_ae_test(
 
     nb_test_samples, acc_test_loss = 0, 0.0
 
-    for input, mask_generate, mask_loss in ae_batches(
+    for x_0, mask_generate in ae_batches(
         quiz_machine,
         args.nb_test_samples,
         data_structures,
@@ -999,10 +994,10 @@ def run_ae_test(
         c_quizzes=c_quizzes,
         desc="test",
     ):
-        targets, logits = targets_and_logits(model, input, mask_generate)
-        loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
-        acc_test_loss += loss.item() * input.size(0)
-        nb_test_samples += input.size(0)
+        logits = logits_hat_x_0_from_random_iteration(model, x_0, mask_generate)
+        loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+        acc_test_loss += loss.item() * x_0.size(0)
+        nb_test_samples += x_0.size(0)
 
     log_string(
         f"{prefix}test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
     )
@@ -1012,7 +1007,7 @@ def run_ae_test(
 
     nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
 
-    for input, mask_generate, mask_loss in ae_batches(
+    for x_0, mask_generate in ae_batches(
         quiz_machine,
         args.nb_test_samples,
         data_structures,
@@ -1020,13 +1015,12 @@ def run_ae_test(
         c_quizzes=c_quizzes,
         desc="test",
    ):
-        targets = input.clone()
         result = ae_generate(
             model,
-            (1 - mask_generate) * input,
+            (1 - mask_generate) * x_0,
             mask_generate,
         )
-        correct = (result == targets).min(dim=1).values.long()
+        correct = (result == x_0).min(dim=1).values.long()
         predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
             :, :, 1
         ]
@@ -1052,49 +1046,16 @@
 
     result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-    # l = [model_ae_proba_solutions(model, result) for model in other_models]
-    # probas = torch.cat([x[:, None] for x in l], dim=1)
-    # comments = []
-
-    # for l in probas:
-    # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         filename,
         quizzes=result[:128],
         predicted_parts=predicted_parts[:128],
         correct_parts=correct_parts[:128],
-        # comments=comments,
     )
 
     log_string(f"wrote {filename}")
 
-    # Prediction with functional perturbations
-
-    # input, mask_generate, mask_loss = next(
-    # ae_batches(
-    # quiz_machine,
-    # [
-    # (
-    # ("A", "f_A", "B", "f_B"),
-    # (0, 0, 0, 1),
-    # (0, 0, 1, 0),
-    # (0, 0, 0, 1),
-    # ),
-    # ],
-    # local_device,
-    # desc=None,
-    # )
-    # )
-    # targets = input.clone()
-    # p = torch.rand(4,model.f_tokens.size(1)).sort(dim=1).indices
-    # def change_theta(theta_A, theta_B):
-    # theta
-    # result = ae_generate(
-    # model, (1 - mask_generate) * input, mask_generate
-    # )
 
 ######################################################################
@@ -1105,7 +1066,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input, mask_generate, mask_loss in ae_batches(
+    for x_0, mask_generate in ae_batches(
         quiz_machine,
         args.nb_train_samples,
         data_structures,
@@ -1113,20 +1074,19 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         c_quizzes,
         "training",
    ):
-        input = input.to(local_device)
+        x_0 = x_0.to(local_device)
         mask_generate = mask_generate.to(local_device)
-        mask_loss = mask_loss.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        targets, logits = targets_and_logits(
-            model, input, mask_generate, prompt_noise=args.prompt_noise
+        logits = logits_hat_x_0_from_random_iteration(
+            model, x_0, mask_generate, prompt_noise=args.prompt_noise
         )
 
-        loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
-        acc_train_loss += loss.item() * input.size(0)
-        nb_train_samples += input.size(0)
+        loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+        acc_train_loss += loss.item() * x_0.size(0)
+        nb_train_samples += x_0.size(0)
 
         loss.backward()
-- 
2.39.5
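
For reference, here is a minimal standalone sketch (not code from main.py) of the degradation scheme that degrade_input_to_generate() implements in the patch above: after n noising iterations a token has been resampled at least once with probability 1 - (1 - p)^n, where p plays the role of args.diffusion_noise_proba, and resampled tokens are drawn uniformly among the colors; the caller then keeps the degraded values only on the mask_generate positions. The vocabulary size, noise probability, and tensor shape below are illustrative assumptions.

import torch

nb_colors = 12      # assumed vocabulary size (one token per grid color)
noise_proba = 0.05  # plays the role of args.diffusion_noise_proba
x_0 = torch.randint(nb_colors, (4, 100))  # assumed batch of 4 flattened quizzes


def degrade(x_0, n, nb_colors=nb_colors, noise_proba=noise_proba):
    # Probability that a given token has been resampled at least once
    # after n independent noising iterations
    proba_erased = 1 - (1 - noise_proba) ** n
    noise = torch.randint(nb_colors, x_0.size(), device=x_0.device)
    mask_erased = (torch.rand(x_0.size(), device=x_0.device) <= proba_erased).long()
    # Keep the original token where the mask is 0, the random token where it is 1
    return (1 - mask_erased) * x_0 + mask_erased * noise


x_t = degrade(x_0, n=25)
# Fraction of tokens that actually changed (a resampled token can keep its color)
print((x_t != x_0).float().mean().item())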