From: François Fleuret Date: Sat, 7 Sep 2024 07:20:27 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=69b6a2d467008e14de972954405b249b62e6b413;p=culture.git Update. --- diff --git a/main.py b/main.py index b926f8e..264b5c7 100755 --- a/main.py +++ b/main.py @@ -991,9 +991,7 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50): hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + ( 1 - one_iteration_prediction - ) * sample_x_t_minus_1_given_x_0_x_t( - hat_x_0, x_t, max(1, args.nb_diffusion_iterations - it) - ) + ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t) if hat_x_t_minus_1.equal(x_t): # log_string(f"exit after {it+1} iterations") @@ -1035,11 +1033,11 @@ def model_ae_proba_solutions(model, input, log_proba=False): return (-loss).exp() -def model_ae_argmax_nb_disagreements(model, input): +def model_ae_argmax_nb_mistakes(model, input): record = [] for x_0 in input.split(args.batch_size): - nb_disagreements = 0 + nb_mistakes = 0 for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: mask_generate = quiz_machine.make_quiz_mask( quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad @@ -1050,11 +1048,11 @@ def model_ae_argmax_nb_disagreements(model, input): predicted = logits.argmax(dim=-1) - nb_disagreements = nb_disagreements + ( + nb_mistakes = nb_mistakes + ( mask_generate * predicted != mask_generate * x_0 ).long().sum(dim=1) - record.append(nb_disagreements) + record.append(nb_mistakes) return torch.cat(record, dim=0) @@ -1275,40 +1273,37 @@ def save_badness_statistics( ###################################################################### -def c_quiz_criterion_one_good_one_bad(probas): - return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25) - - -def c_quiz_criterion_one_good_no_very_bad(probas): - return ( - (probas.max(dim=1).values >= 0.75) - & (probas.min(dim=1).values <= 0.75) - & (probas.min(dim=1).values >= 0.25) - ) - - -def c_quiz_criterion_diff(probas): - return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5 +def quiz_validation(models, c_quizzes, local_device): + nb_have_to_be_correct = args.nb_models // 2 + nb_have_to_be_wrong = args.nb_models // 5 + nb_runs = 3 + nb_mistakes_to_be_wrong = 5 -def c_quiz_criterion_diff2(probas): - v = probas.sort(dim=1).values - return (v[:, -2] - v[:, 0]) >= 0.5 + record_wrong = [] + nb_correct, nb_wrong = 0, 0 + for i, model in enumerate(models): + assert i == model.id # a bit of paranoia + model = copy.deepcopy(model).to(local_device).eval() + correct, wrong = True, False + for _ in range(nb_runs): + n = model_ae_argmax_nb_mistakes(model, c_quizzes).long() + correct = correct & (n == 0) + wrong = wrong | (n >= nb_mistakes_to_be_wrong) + record_wrong.append(wrong[:, None]) + nb_correct += correct.long() + nb_wrong += wrong.long() -def c_quiz_criterion_few_good_one_bad(probas): - v = probas.sort(dim=1).values - return (v[:, 0] <= 0.25) & (v[:, -3] >= 0.5) + # print("nb_correct", nb_correct) + # print("nb_wrong", nb_wrong) -def c_quiz_criterion_two_good(probas): - return ((probas >= 0.5).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.2) + to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong) + wrong = torch.cat(record_wrong, dim=1) -def c_quiz_criterion_some(probas): - return ((probas >= 0.8).long().sum(dim=1) >= 1) & ( - (probas <= 0.2).long().sum(dim=1) >= 1 - ) + return to_keep, wrong def generate_ae_c_quizzes(models, nb, local_device=main_device): @@ -1346,33 +1341,12 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): c_quizzes = c_quizzes[to_keep] if c_quizzes.size(0) > 0: - # p = [ - # model_ae_proba_solutions(model, c_quizzes)[:, None] - # for model in models - # ] - - # probas = torch.cat(p, dim=1) - # to_keep = c_quiz_criterion_two_good(probas) - - nb_disagreements = [] - for i, model in enumerate(models): - assert i == model.id # a bit of paranoia - model = copy_for_inference(model) - nb_disagreements.append( - model_ae_argmax_nb_disagreements(model, c_quizzes).long()[ - :, None - ] - ) - nb_disagreements = torch.cat(nb_disagreements, dim=1) - - v = nb_disagreements.sort(dim=1).values - to_keep = (v[:, 2] == 0) & (v[:, -1] >= 4) - + to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device) q = c_quizzes[to_keep] if q.size(0) > 0: record_c_quizzes.append(q) - a = (nb_disagreements == 0)[to_keep] + a = (record_wrong == False)[to_keep] record_agreements.append(a) nb_c_quizzes_per_model += a.long().sum(dim=0) @@ -1405,25 +1379,23 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): subset_c_quizzes = c_quizzes[:nb_to_save] - # #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # for model in models: - # model = copy_for_inference(model) - # prediction = model_ae_argmax_predictions(model, subset_c_quizzes) - # filename = f"prediction_c_quiz_{n_epoch:04d}_{model.id}.png" + # for r in range(3): + # filename = f"culture_c_quiz_{n_epoch:04d}_prediction_{model.id}_{r}.png" + # p = model_ae_argmax_predictions(copy_for_inference(model), subset_c_quizzes) # quiz_machine.problem.save_quizzes_as_image( # args.result_dir, # filename, - # quizzes=prediction, + # quizzes=p, + # delta=True, # nrow=8, # ) # log_string(f"wrote {filename}") - # exit(0) - #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! filename = f"culture_c_quiz_{n_epoch:04d}.png" - # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record) - l = [ model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes) for model in models @@ -1438,8 +1410,6 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): args.result_dir, filename, quizzes=subset_c_quizzes, - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, comments=comments, delta=True, nrow=8, @@ -1482,9 +1452,6 @@ if args.resume: log_string(f"successfully loaded {filename}") current_epoch = state["current_epoch"] c_quizzes = state["c_quizzes"] - # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"] - # total_time_training_models = state["total_time_training_models"] - # common_c_quiz_bags = state["common_c_quiz_bags"] except FileNotFoundError: log_string(f"cannot find {filename}") pass @@ -1510,9 +1477,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): state = { "current_epoch": n_epoch, "c_quizzes": c_quizzes, - # "total_time_generating_c_quizzes": total_time_generating_c_quizzes, - # "total_time_training_models": total_time_training_models, - # "common_c_quiz_bags": common_c_quiz_bags, } filename = "state.pth" @@ -1526,28 +1490,27 @@ for n_epoch in range(current_epoch, args.nb_epochs): # -------------------------------------------------------------------- - # run_ae_test( - # model, - # alien_quiz_machine, - # n_epoch, - # c_quizzes=None, - # local_device=main_device, - # prefix="alien", - # ) - - # exit(0) - - # one_ae_epoch(models[0], quiz_machine, n_epoch, None, main_device) - # exit(0) - log_string(f"{time_train=} {time_c_quizzes=}") if ( min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes and time_train >= time_c_quizzes ): - if c_quizzes is not None: - save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after") + if c_quizzes is None: + for model in models: + filename = f"ae_{model.id:03d}_naive.pth" + torch.save( + { + "state_dict": model.state_dict(), + "optimizer_state_dict": model.optimizer.state_dict(), + "test_accuracy": model.test_accuracy, + }, + os.path.join(args.result_dir, filename), + ) + + log_string(f"wrote {filename}") + + # -------------------------------------------------------------------- last_n_epoch_c_quizzes = n_epoch nb_gpus = len(gpus) @@ -1579,8 +1542,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): c_quizzes = torch.cat([q.to(main_device) for q, _ in records], dim=0) agreements = torch.cat([a.to(main_device) for _, a in records], dim=0) - print(f"DEBUG {c_quizzes.size()=} {agreements.size()=}") - # -------------------------------------------------------------------- log_string(f"generated_c_quizzes {c_quizzes.size()=}") @@ -1589,8 +1550,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): for model in models: model.test_accuracy = 0 - save_badness_statistics(n_epoch, models, c_quizzes, "before") - if c_quizzes is None: log_string("no_c_quiz") else: @@ -1603,9 +1562,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): threads = [] - # for model in models: - # log_string(f"DEBUG {model.id} {sum([ p.sum() for p in model.parameters()]).item()}") - start_time = time.perf_counter() for gpu, model in zip(gpus, weakest_models): @@ -1639,10 +1595,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): "state_dict": model.state_dict(), "optimizer_state_dict": model.optimizer.state_dict(), "test_accuracy": model.test_accuracy, - # "gen_test_accuracy": model.gen_test_accuracy, - # "gen_state_dict": model.gen_state_dict, - # "train_c_quiz_bags": model.train_c_quiz_bags, - # "test_c_quiz_bags": model.test_c_quiz_bags, }, os.path.join(args.result_dir, filename), )