From 72fe1655f90e94bbb534f875d998e6cc18e648a9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 5 Sep 2024 11:50:34 +0200 Subject: [PATCH] Update. --- main.py | 138 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 80 insertions(+), 58 deletions(-) diff --git a/main.py b/main.py index 934940e..e95c4f6 100755 --- a/main.py +++ b/main.py @@ -839,7 +839,28 @@ def model_ae_proba_solutions(model, input, log_proba=False): return (-loss).exp() -nb_diffusion_iterations = 25 +def model_ae_argmax_nb_disagreements(model, input): + record = [] + + for q in input.split(args.batch_size): + nb_disagreements = 0 + for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: + mask_generate = quiz_machine.make_quiz_mask( + quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad + ) + targets, logits = targets_and_prediction( + model, q, mask_generate, prompt_noise=args.prompt_noise + ) + + predicted = logits.argmax(dim=-1) + + nb_disagreements = nb_disagreements + ( + mask_generate * predicted != mask_generate * targets + ).long().sum(dim=1) + + record.append(nb_disagreements) + + return torch.cat(record, dim=0) def degrade_input_to_generate(input, mask_generate, steps_nb_iterations): @@ -1152,20 +1173,7 @@ def c_quiz_criterion_some(probas): def generate_ae_c_quizzes(models, nb, local_device=main_device): - criteria = [ - c_quiz_criterion_few_good_one_bad, - # c_quiz_criterion_only_one, - # c_quiz_criterion_one_good_one_bad, - # c_quiz_criterion_one_good_no_very_bad, - # c_quiz_criterion_diff, - # c_quiz_criterion_diff2, - # c_quiz_criterion_two_good, - # c_quiz_criterion_some, - ] - # To be thread-safe we must make copies - models = [copy.deepcopy(model).to(local_device) for model in models] - quad_order = ("A", "f_A", "B", "f_B") template = quiz_machine.problem.create_empty_quizzes( @@ -1176,45 +1184,57 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1) ) - duration_max = 4 * 3600 + def copy_for_inference(model): + return copy.deepcopy(model).to(local_device).eval() wanted_nb = nb nb_to_save = 256 with torch.autograd.no_grad(): - records = [[] for _ in criteria] + records = [] last_log = -1 start_time = time.perf_counter() - while ( - time.perf_counter() < start_time + duration_max - and min([bag_len(bag) for bag in records]) < wanted_nb - ): - model = models[torch.randint(len(models), (1,)).item()] + while bag_len(records) < wanted_nb: + model = copy_for_inference(models[torch.randint(len(models), (1,)).item()]) + c_quizzes = ae_generate(model, template, mask_generate) to_keep = quiz_machine.problem.trivial(c_quizzes) == False c_quizzes = c_quizzes[to_keep] if c_quizzes.size(0) > 0: - probas = torch.cat( - [ - model_ae_proba_solutions(model, c_quizzes)[:, None] - for model in models - ], - dim=1, - ) + # p = [ + # model_ae_proba_solutions(model, c_quizzes)[:, None] + # for model in models + # ] + + # probas = torch.cat(p, dim=1) + # to_keep = c_quiz_criterion_two_good(probas) - for c, r in zip(criteria, records): - q = c_quizzes[c(probas)] - if q.size(0) > 0: - r.append(q) + nb_disagreements = [] + for model in models: + model = copy_for_inference(model) + nb_disagreements.append( + model_ae_argmax_nb_disagreements(model, c_quizzes).long()[ + :, None + ] + ) + nb_disagreements = torch.cat(nb_disagreements, dim=1) + + v = nb_disagreements.sort(dim=1).values + to_keep = (v[:, 1] == 0) & (v[:, -1] > 3) + + q = c_quizzes[to_keep] + + if q.size(0) > 0: + records.append(q) duration = time.perf_counter() - start_time - nb_generated = min([bag_len(bag) for bag in records]) + nb_generated = bag_len(records) - if last_log < 0 or duration > last_log + 60: + if last_log < 0 or duration > last_log + 5: last_log = duration if nb_generated > 0: if nb_generated < wanted_nb: @@ -1227,44 +1247,46 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): else: e = "???" - bl = [bag_len(bag) for bag in records] log_string( - f"bag_len {bl} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)" + f"nb_generated {bag_len(records)} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)" ) duration = time.perf_counter() - start_time log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h") - for n, u in enumerate(records): - quizzes = torch.cat(u, dim=0)[:nb_to_save] - filename = f"culture_c_quiz_{n_epoch:04d}_{n:02d}.png" + c_quizzes = torch.cat(records, dim=0).unique(dim=0) - # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record) + subset_c_quizzes = c_quizzes[:nb_to_save] - l = [model_ae_proba_solutions(model, quizzes) for model in models] - probas = torch.cat([x[:, None] for x in l], dim=1) - comments = [] + filename = f"culture_c_quiz_{n_epoch:04d}.png" - for l in probas: - comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record) - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=quizzes, - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, - comments=comments, - delta=True, - nrow=8, - ) + l = [ + model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes) + for model in models + ] + probas = torch.cat([x[:, None] for x in l], dim=1) + comments = [] - log_string(f"wrote {filename}") + for l in probas: + comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) - a = [torch.cat(u, dim=0) for u in records] + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=subset_c_quizzes, + # predicted_parts=predicted_parts, + # correct_parts=correct_parts, + comments=comments, + delta=True, + nrow=8, + ) + + log_string(f"wrote {filename}") - return torch.cat(a, dim=0).unique(dim=0) + return c_quizzes def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device): -- 2.39.5