From 73b65458082c81d7d23db242da691c9cfc7b1400 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 5 Sep 2024 08:28:26 +0200 Subject: [PATCH] Update. --- main.py | 114 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 59 insertions(+), 55 deletions(-) diff --git a/main.py b/main.py index 02c9fc6..934940e 100755 --- a/main.py +++ b/main.py @@ -57,11 +57,13 @@ parser.add_argument("--nb_train_samples", type=int, default=25000) parser.add_argument("--nb_test_samples", type=int, default=1000) +parser.add_argument("--nb_c_quizzes", type=int, default=2500) + parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) -parser.add_argument("--c_quiz_multiplier", type=int, default=10) +parser.add_argument("--c_quiz_multiplier", type=int, default=1) parser.add_argument("--learning_rate", type=float, default=5e-4) @@ -115,7 +117,7 @@ parser.add_argument("--temperature_hot", type=float, default=1.5) parser.add_argument("--temperature_cold", type=float, default=1) -parser.add_argument("--prompt_noise", type=float, default=0.0) +parser.add_argument("--prompt_noise", type=float, default=0.05) parser.add_argument("--dirty_debug", action="store_true", default=False) @@ -754,10 +756,10 @@ def deterministic(mask_generate): return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long() -# This function returns a tensor of same shape as low, full of uniform -# random values in [0,1], such that the values corresponding to the -# True in low are all lesser than the values corresponding to the -# False. +# This function returns a 2d tensor of same shape as low, full of +# uniform random values in [0,1], such that, in every row, the values +# corresponding to the True in low are all lesser than the values +# corresponding to the False. def prioritized_rand(low): @@ -840,7 +842,7 @@ def model_ae_proba_solutions(model, input, log_proba=False): nb_diffusion_iterations = 25 -def degrade_input_to_generate(input, mask_generate, nb_iterations): +def degrade_input_to_generate(input, mask_generate, steps_nb_iterations): noise = torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) @@ -849,7 +851,7 @@ def degrade_input_to_generate(input, mask_generate, nb_iterations): result = [] - for n in nb_iterations: + for n in steps_nb_iterations: proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n mask_erased = mask_generate * (r <= proba_erased[:, None]).long() x = (1 - mask_erased) * input + mask_erased * noise @@ -929,8 +931,8 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_ args.nb_test_samples, data_structures, local_device, - c_quizzes, - "test", + c_quizzes=c_quizzes, + desc="test", ): targets = input.clone() result = ae_generate( @@ -1080,6 +1082,39 @@ for i in range(args.nb_models): ###################################################################### +def save_badness_statistics( + n_epoch, models, c_quizzes, suffix=None, local_device=main_device +): + for model in models: + model.eval().to(local_device) + c_quizzes = c_quizzes.to(local_device) + with torch.autograd.no_grad(): + log_probas = sum( + [model_ae_proba_solutions(model, c_quizzes) for model in models] + ) + i = log_probas.sort().indices + + suffix = "" if suffix is None else "_" + suffix + + filename = f"culture_badness_{n_epoch:04d}{suffix}.png" + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=c_quizzes[i[:128]], + # predicted_parts=predicted_parts, + # correct_parts=correct_parts, + # comments=comments, + delta=True, + nrow=8, + ) + + log_string(f"wrote {filename}") + + +###################################################################### + + def c_quiz_criterion_one_good_one_bad(probas): return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25) @@ -1101,9 +1136,9 @@ def c_quiz_criterion_diff2(probas): return (v[:, -2] - v[:, 0]) >= 0.5 -def c_quiz_criterion_only_one(probas): +def c_quiz_criterion_few_good_one_bad(probas): v = probas.sort(dim=1).values - return (v[:, -1] >= 0.75) & (v[:, -2] <= 0.25) + return (v[:, 0] <= 0.25) & (v[:, -3] >= 0.5) def c_quiz_criterion_two_good(probas): @@ -1116,40 +1151,11 @@ def c_quiz_criterion_some(probas): ) -def save_badness_statistics( - n_epoch, models, c_quizzes, suffix=None, local_device=main_device -): - for model in models: - model.eval().to(local_device) - c_quizzes = c_quizzes.to(local_device) - with torch.autograd.no_grad(): - log_probas = sum( - [model_ae_proba_solutions(model, c_quizzes) for model in models] - ) - i = log_probas.sort().indices - - suffix = "" if suffix is None else "_" + suffix - - filename = f"culture_badness_{n_epoch:04d}{suffix}.png" - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=c_quizzes[i[:128]], - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, - # comments=comments, - delta=True, - nrow=8, - ) - - log_string(f"wrote {filename}") - - def generate_ae_c_quizzes(models, nb, local_device=main_device): criteria = [ + c_quiz_criterion_few_good_one_bad, # c_quiz_criterion_only_one, - c_quiz_criterion_one_good_one_bad, + # c_quiz_criterion_one_good_one_bad, # c_quiz_criterion_one_good_no_very_bad, # c_quiz_criterion_diff, # c_quiz_criterion_diff2, @@ -1186,22 +1192,22 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): and min([bag_len(bag) for bag in records]) < wanted_nb ): model = models[torch.randint(len(models), (1,)).item()] - result = ae_generate(model, template, mask_generate) + c_quizzes = ae_generate(model, template, mask_generate) - to_keep = quiz_machine.problem.trivial(result) == False - result = result[to_keep] + to_keep = quiz_machine.problem.trivial(c_quizzes) == False + c_quizzes = c_quizzes[to_keep] - if result.size(0) > 0: + if c_quizzes.size(0) > 0: probas = torch.cat( [ - model_ae_proba_solutions(model, result)[:, None] + model_ae_proba_solutions(model, c_quizzes)[:, None] for model in models ], dim=1, ) for c, r in zip(criteria, records): - q = result[c(probas)] + q = c_quizzes[c(probas)] if q.size(0) > 0: r.append(q) @@ -1234,7 +1240,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): quizzes = torch.cat(u, dim=0)[:nb_to_save] filename = f"culture_c_quiz_{n_epoch:04d}_{n:02d}.png" - # result, predicted_parts, correct_parts = bag_to_tensors(record) + # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record) l = [model_ae_proba_solutions(model, quizzes) for model in models] probas = torch.cat([x[:, None] for x in l], dim=1) @@ -1351,9 +1357,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): last_n_epoch_c_quizzes = n_epoch nb_gpus = len(gpus) - nb_c_quizzes_to_generate = ( - args.nb_train_samples // args.c_quiz_multiplier + nb_gpus - 1 - ) // nb_gpus + nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus # -------------------------------------------------------------------- @@ -1376,7 +1380,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): for t in threads: t.join() - time_c_quizzes = time.perf_counter() - start_time + time_c_quizzes = int(time.perf_counter() - start_time) c_quizzes = torch.cat([q.to(main_device) for q in records], dim=0) @@ -1420,7 +1424,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): for t in threads: t.join() - time_train += time.perf_counter() - start_time + time_train += int(time.perf_counter() - start_time) # -------------------------------------------------------------------- -- 2.39.5