From abe23f86eff172717249552e9abf5b07c3be054e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 19 Aug 2024 20:38:42 +0200 Subject: [PATCH] Update. --- main.py | 74 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/main.py b/main.py index 1cbff39..cd1e10f 100755 --- a/main.py +++ b/main.py @@ -93,11 +93,7 @@ parser.add_argument("--gpus", type=str, default="all") # ---------------------------------- -parser.add_argument("--nb_gpts", type=int, default=5) - -parser.add_argument("--min_succeed_to_validate", type=int, default=2) - -parser.add_argument("--max_fail_to_validate", type=int, default=3) +parser.add_argument("--nb_gpts", type=int, default=2) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) @@ -340,8 +336,9 @@ def run_tests(model, quiz_machine, local_device=main_device): nb_samples_accumulated = 0 full_input, full_mask_loss = quiz_machine.data_input( - args.nb_test_samples, model.test_c_quiz_bags + args.nb_test_samples, test_c_quiz_bags ) + src = zip( full_input.split(args.batch_size), full_mask_loss.split(args.batch_size) ) @@ -368,7 +365,7 @@ def run_tests(model, quiz_machine, local_device=main_device): log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}") - input, _ = quiz_machine.data_input(2000, model.test_c_quiz_bags) + input, _ = quiz_machine.data_input(1000, test_c_quiz_bags) model.test_accuracy = quiz_machine.produce_results( n_epoch=n_epoch, @@ -391,7 +388,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): nb_train_samples, acc_train_loss = 0, 0.0 full_input, full_mask_loss = quiz_machine.data_input( - args.nb_train_samples, model.train_c_quiz_bags + args.nb_train_samples, train_c_quiz_bags ) src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)) @@ -557,7 +554,15 @@ def model_proba_solutions(model, quizzes): return l.exp() -def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_for_test): +def create_c_quizzes( + main_model, + other_models, + quiz_machine, + nb_for_train, + train_c_quiz_bags, + nb_for_test, + test_c_quiz_bags, +): nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models) nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate @@ -641,7 +646,7 @@ def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_fo e = "???" log_string( - f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%" + f"keep c_quizzes model {main_model.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%" ) # Save some images @@ -661,10 +666,9 @@ def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_fo args.result_dir, filename, c_quizzes[:128], comments=comments ) - -log_string( - f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}" -) + log_string( + f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in train_c_quiz_bags ])} test {sum([q.size(0) for q in test_c_quiz_bags ])}" + ) ###################################################################### @@ -709,8 +713,6 @@ for k in range(args.nb_gpts): ) model.id = k - model.train_c_quiz_bags = [] - model.test_c_quiz_bags = [] if args.schedule_free: model.optimizer = schedulefree.AdamWScheduleFree( @@ -724,6 +726,9 @@ for k in range(args.nb_gpts): ###################################################################### +train_c_quiz_bags = [] +test_c_quiz_bags = [] + current_epoch = 0 if args.resume: @@ -735,8 +740,6 @@ if args.resume: model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.test_accuracy = d["test_accuracy"] - model.train_c_quiz_bags = d["train_c_quiz_bags"] - model.test_c_quiz_bags = d["test_c_quiz_bags"] log_string(f"successfully loaded {filename}") except FileNotFoundError: log_string(f"cannot find {filename}") @@ -747,6 +750,8 @@ if args.resume: state = torch.load(os.path.join(args.result_dir, filename)) log_string(f"successfully loaded {filename}") current_epoch = state["current_epoch"] + train_c_quiz_bags = d["train_c_quiz_bags"] + test_c_quiz_bags = d["test_c_quiz_bags"] except FileNotFoundError: log_string(f"cannot find {filename}") pass @@ -759,10 +764,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### if args.nb_new_c_quizzes_for_train is None: - args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250 + args.nb_new_c_quizzes_for_train = args.nb_train_samples if args.nb_new_c_quizzes_for_test is None: - args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250 + args.nb_new_c_quizzes_for_test = args.nb_test_samples log_string( f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}" @@ -850,6 +855,8 @@ def save_generated_c_quizzes(model, filename, nb=64): for n_epoch in range(current_epoch, args.nb_epochs): state = { "current_epoch": n_epoch, + "train_c_quiz_bags": train_c_quiz_bags, + "test_c_quiz_bags": test_c_quiz_bags, } filename = "state.pth" torch.save(state, os.path.join(args.result_dir, filename)) @@ -863,11 +870,14 @@ for n_epoch in range(current_epoch, args.nb_epochs): ################################################## if min([m.test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: - record_new_c_quizzes( - models, - quiz_machine, - args.nb_new_c_quizzes_for_train, - args.nb_new_c_quizzes_for_test, + create_c_quizzes( + main_model=models[0], + other_models=models[1:], + quiz_machine=quiz_machine, + nb_for_train=args.nb_new_c_quizzes_for_train, + train_c_quiz_bags=train_c_quiz_bags, + nb_for_test=args.nb_new_c_quizzes_for_test, + test_c_quiz_bags=test_c_quiz_bags, ) for model in models: @@ -883,8 +893,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): ).to(main_device) model.load_state_dict(new_model.state_dict()) model.test_accuracy = 0.0 - model.best_test_accuracy = 0.0 - model.best_dict = copy.deepcopy(model.state_dict()) ################################################## # Select, improve, and eval the worst model(s) @@ -894,11 +902,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # This ugly recipe will pick the worst if there some below # args.accuracy_to_make_c_quizzes or one at random if they # are all above - key=lambda m: float( - m.test_accuracy - if m.test_accuracy < args.accuracy_to_make_c_quizzes - else args.accuracy_to_make_c_quizzes + torch.rand(1).item() - ), + key=lambda m: float(m.test_accuracy), ) weakest_models = ranked_models[: len(gpus)] @@ -921,8 +925,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): for t in threads: t.join() - total_time_training_models += time.perf_counter() - start_time - for model in weakest_models: save_additional_results(n_epoch, model, models, c_quizzes_procedure) @@ -935,10 +937,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): "state_dict": model.state_dict(), "optimizer_state_dict": model.optimizer.state_dict(), "test_accuracy": model.test_accuracy, - "best_test_accuracy": model.best_test_accuracy, - "best_dict": model.best_dict, - "train_c_quiz_bags": model.train_c_quiz_bags, - "test_c_quiz_bags": model.test_c_quiz_bags, }, os.path.join(args.result_dir, filename), ) -- 2.39.5