From e6448c688761d2cd128009d663411c8337018a54 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 31 Aug 2024 18:37:02 +0200
Subject: [PATCH] Update.

---
 main.py | 176 ++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 138 insertions(+), 38 deletions(-)

diff --git a/main.py b/main.py
index 8726c96..c11f5c2 100755
--- a/main.py
+++ b/main.py
@@ -95,6 +95,8 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_models", type=int, default=5)
 
+parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
+
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
 parser.add_argument("--max_fail_to_validate", type=int, default=3)
@@ -309,6 +311,17 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
+
+def bag_len(bag):
+    return sum([x[0].size(0) for x in bag])
+
+
+def bag_to_tensors(bag):
+    return tuple(torch.cat([x[i] for x in bag], dim=0) for i in range(len(bag[0])))
+
+
+######################################################################
+
 # If we need to move an optimizer to a different device
 
 
@@ -945,10 +958,6 @@ class FunctionalAE(nn.Module):
 
 ######################################################################
 
-nb_iterations = 25
-probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device)
-probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-
 
 def ae_batches(
     quiz_machine,
@@ -1024,6 +1033,8 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
 
     changed = True
     for it in range(nb_iterations_max):
+        print(f"{it=} {nb_iterations_max=}")
+
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
@@ -1044,7 +1055,8 @@
         changed = changed & (update != input).max(dim=1).values
         input[changed] = update[changed]
 
-        log_string(f"remains {changed.long().sum()}")
+        if it == nb_iterations_max:
+            log_string(f"remains {changed.long().sum()}")
 
     return input
 
@@ -1062,9 +1074,7 @@ def model_ae_proba_solutions(model, input):
         mask_generate = quiz_machine.make_quiz_mask(
             quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
        )
-        targets, logits = targets_and_prediction(
-            probs_iterations, model, q, mask_generate
-        )
+        targets, logits = targets_and_prediction(model, q, mask_generate)
         loss_per_token = F.cross_entropy(
             logits.transpose(1, 2), targets, reduction="none"
         )
@@ -1076,6 +1086,9 @@
     return (-loss).exp()
 
 
+nb_diffusion_iterations = 25
+
+
 def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
@@ -1094,14 +1107,18 @@
     return result
 
 
-def targets_and_prediction(probs_iterations, model, input, mask_generate):
+def targets_and_prediction(model, input, mask_generate):
     d = deterministic(mask_generate)
-    p = probs_iterations.expand(input.size(0), -1)
-    dist = torch.distributions.categorical.Categorical(probs=p)
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.nb_diffusion_iterations, device=input.device
+    )
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(input.size(0), -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
     N0 = dist.sample()
     N1 = N0 + 1
     N0 = (1 - d) * N0
-    N1 = (1 - d) * N1 + d * nb_iterations
+    N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
 
     targets, input = degrade_input(
         input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
@@ -1113,7 +1130,7 @@
     return targets, logits
 
 
-def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_device):
+def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
     with torch.autograd.no_grad():
         model.eval().to(local_device)
 
@@ -1128,9 +1145,7 @@
             local_device,
             "test",
         ):
-            targets, logits = targets_and_prediction(
-                probs_iterations, model, input, mask_generate
-            )
+            targets, logits = targets_and_prediction(model, input, mask_generate)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
@@ -1173,28 +1188,27 @@
 
 
     model.test_accuracy = nb_correct / nb_total
 
-    for f, record in [("prediction", record_d), ("generation", record_nd)]:
-        filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
-        result, predicted_parts, correct_parts = (
-            torch.cat([x[i] for x in record])[:128] for i in [0, 1, 2]
-        )
+    # for f, record in [("prediction", record_d), ("generation", record_nd)]:
+    # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
 
-        l = [model_ae_proba_solutions(model, result) for model in other_models]
-        probas = torch.cat([x[:, None] for x in l], dim=1)
-        comments = []
+    # result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-        for l in probas:
-            comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+    # l = [model_ae_proba_solutions(model, result) for model in other_models]
+    # probas = torch.cat([x[:, None] for x in l], dim=1)
+    # comments = []
 
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            filename,
-            quizzes=result,
-            predicted_parts=predicted_parts,
-            correct_parts=correct_parts,
-            comments=comments,
-        )
-        log_string(f"wrote {filename}")
+    # for l in probas:
+    # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+    # quiz_machine.problem.save_quizzes_as_image(
+    # args.result_dir,
+    # filename,
+    # quizzes=result,
+    # predicted_parts=predicted_parts,
+    # correct_parts=correct_parts,
+    # comments=comments,
+    # )
+    # log_string(f"wrote {filename}")
 
     # Prediction with functional perturbations
 
@@ -1237,12 +1251,19 @@ def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_d
         local_device,
         "training",
     ):
+        input = input.to(local_device)
+        mask_generate = mask_generate.to(local_device)
+        mask_loss = mask_loss.to(local_device)
+
        if nb_train_samples % args.batch_size == 0:
            model.optimizer.zero_grad()
 
-        targets, logits = targets_and_prediction(
-            probs_iterations, model, input, mask_generate
-        )
+        targets, logits = targets_and_prediction(model, input, mask_generate)
+
+        print(
+            f"{input.device=} {logits.device=} {targets.device=} {logits.device=} {mask_loss.device=}"
+        )
+
         loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
@@ -1256,7 +1277,7 @@
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
+    # run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
 
 
 ######################################################################
@@ -1288,6 +1309,83 @@ for i in range(args.nb_models):
 
 ######################################################################
 
+
+def c_quiz_criterion_one_good_one_bad(probas):
+    return (probas.max(dim=1).values >= 0.8) & (probas.min(dim=1).values <= 0.2)
+
+
+def c_quiz_criterion_diff(probas):
+    return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5
+
+
+def c_quiz_criterion_two_certains(probas):
+    return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5)
+
+
+def generate_ae_c_quizzes(models, local_device=main_device):
+    criteria = [
+        c_quiz_criterion_one_good_one_bad,
+        c_quiz_criterion_diff,
+        c_quiz_criterion_two_certains,
+    ]
+
+    for m in models:
+        m.eval().to(local_device)
+
+    quad_order = ("A", "f_A", "B", "f_B")
+
+    template = quiz_machine.problem.create_empty_quizzes(
+        nb=args.batch_size, quad_order=quad_order
+    ).to(local_device)
+
+    mask_generate = quiz_machine.make_quiz_mask(
+        quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
+    )
+
+    records = [[] for _ in criteria]
+
+    with torch.autograd.no_grad():
+        while min([bag_len(bag) for bag in records]) < 128:
+            bl = [bag_len(bag) for bag in records]
+            log_string(f"bag_len {bl}")
+
+            model = models[torch.randint(len(models), (1,)).item()]
+            result = ae_generate(model, template, mask_generate, 0.0)
+
+            probas = torch.cat(
+                [model_ae_proba_solutions(model, result)[:, None] for model in models],
+                dim=1,
+            )
+            for c, r in zip(criteria, records):
+                q = result[c(probas)]
+                if q.size(0) > 0:
+                    r.append(q)
+
+    # for f, record in [("prediction", record_d), ("generation", record_nd)]:
+    # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+
+    # result, predicted_parts, correct_parts = bag_to_tensors(record)
+
+    # l = [model_ae_proba_solutions(model, result) for model in models]
+    # probas = torch.cat([x[:, None] for x in l], dim=1)
+    # comments = []
+
+    # for l in probas:
+    # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+    # quiz_machine.problem.save_quizzes_as_image(
+    # args.result_dir,
+    # filename,
+    # quizzes=result,
+    # predicted_parts=predicted_parts,
+    # correct_parts=correct_parts,
+    # comments=comments,
+    # )
+    # log_string(f"wrote {filename}")
+
+
+######################################################################
+
 current_epoch = 0
 
 if args.resume:
@@ -1374,6 +1472,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         for t in threads:
             t.join()
 
+    generate_ae_c_quizzes(models, local_device=main_device)
+
     # --------------------------------------------------------------------
 
     for model in models:
-- 
2.39.5
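Note on the diffusion schedule: the patch deletes the module-level probs_iterations constant and rebuilds it inside targets_and_prediction() on every call, sized by the new --nb_diffusion_iterations argument (default 25). Below is a minimal standalone sketch of that schedule, separate from the commit itself; T mirrors the patch's default and B is an arbitrary batch size chosen for the demonstration.

import torch

T = 25  # default of --nb_diffusion_iterations
B = 4   # example batch size

# Geometric decay from 1.0 down to 0.1 across the T diffusion steps,
# normalized into a distribution, as in targets_and_prediction().
probs_iterations = 0.1 ** torch.linspace(0, 1, T)
probs_iterations = probs_iterations / probs_iterations.sum()

# Step 0 is ten times more likely to be drawn than step T - 1.
print(probs_iterations[0] / probs_iterations[-1])  # ~tensor(10.)

# One draw per sample, biased toward early (lightly degraded) steps.
dist = torch.distributions.categorical.Categorical(
    probs=probs_iterations[None, :].expand(B, -1)
)
N0 = dist.sample()  # shape (B,), values in [0, T - 1]

In the patch, the sampled N0 becomes N1 = N0 + 1 degradation steps for the model input (pinned to the full args.nb_diffusion_iterations where the mask is deterministic), while the prediction target is degraded 0 * N1 times, i.e. left clean.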
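Note on the c-quiz filters: generate_ae_c_quizzes() scores every generated batch with one solution probability per model (the columns of probas, via model_ae_proba_solutions) and accumulates quizzes into one bag per criterion until each bag holds at least 128 of them. Below is a toy run of the three predicates; the predicates are copied from the patch, while the probas rows are invented for illustration.

import torch

def c_quiz_criterion_one_good_one_bad(probas):
    # one model very confident, another nearly lost
    return (probas.max(dim=1).values >= 0.8) & (probas.min(dim=1).values <= 0.2)

def c_quiz_criterion_diff(probas):
    # wide spread between the most and least confident model
    return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5

def c_quiz_criterion_two_certains(probas):
    # at least two near-certain models, at least one at or below 0.5
    return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5)

probas = torch.tensor(
    [
        [0.99, 0.99, 0.10],  # two certain, one lost
        [0.85, 0.60, 0.15],  # confident/unsure pair, wide spread
        [0.99, 0.99, 0.99],  # full consensus
    ]
)

for c in [
    c_quiz_criterion_one_good_one_bad,
    c_quiz_criterion_diff,
    c_quiz_criterion_two_certains,
]:
    print(c.__name__, c(probas))
# c_quiz_criterion_one_good_one_bad tensor([ True,  True, False])
# c_quiz_criterion_diff tensor([ True,  True, False])
# c_quiz_criterion_two_certains tensor([ True, False, False])

All three criteria reject the consensus row: the bags only collect quizzes on which the population of models disagrees.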