From bd5c4cc653837c63de41073c1ab5c6fb8404a156 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Thu, 15 Aug 2024 16:58:40 +0200
Subject: [PATCH] Update.

---
 main.py | 29 +++++++++++++++++++++--------
 1 file changed, 21 insertions(+), 8 deletions(-)

diff --git a/main.py b/main.py
index 27404e8..e2e9a59 100755
--- a/main.py
+++ b/main.py
@@ -364,7 +364,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
     log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-    model.main_test_accuracy = quiz_machine.produce_results(
+    model.test_accuracy = quiz_machine.produce_results(
         n_epoch=n_epoch,
         model=model,
         input=full_input[:2000],
@@ -890,7 +890,7 @@ for k in range(args.nb_gpts):
    else:
        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    model.main_test_accuracy = 0.0
+    model.test_accuracy = 0.0
 
     models.append(model)
 
@@ -906,7 +906,9 @@ if args.resume:
             d = torch.load(os.path.join(args.result_dir, filename))
             model.load_state_dict(d["state_dict"])
             model.optimizer.load_state_dict(d["optimizer_state_dict"])
-            model.main_test_accuracy = d["main_test_accuracy"]
+            model.test_accuracy = d["test_accuracy"]
+            model.best_test_accuracy = d["best_test_accuracy"]
+            model.best_dict = d["best_dict"]
             model.train_c_quiz_bags = d["train_c_quiz_bags"]
             model.test_c_quiz_bags = d["test_c_quiz_bags"]
             log_string(f"successfully loaded (unknown)")
@@ -1109,14 +1111,23 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
-    cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+    cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
     ##################################################
     # If all the models are good enough, generate new quizzes and
     # re-compute the test errors
 
-    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+    for model in models:
+        if model.test_accuracy >= model.best_test_accuracy:
+            model.best_dict = copy.deepcopy(model.state_dict())
+            model.best_test_accuracy = model.test_accuracy
+
+    if min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+        for model in models:
+            model.current_dict = copy.deepcopy(model.state_dict())
+            model.load_state_dict(model.best_dict)
+
         record_new_c_quizzes(
             models,
             quiz_machine,
@@ -1126,12 +1137,12 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         # Force one epoch of training
 
         for model in models:
-            model.main_test_accuracy = 0.0
+            model.load_state_dict(model.current_dict)
 
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
     weakest_models = ranked_models[: len(gpus)]
 
@@ -1159,7 +1170,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
                 {
                     "state_dict": model.state_dict(),
                     "optimizer_state_dict": model.optimizer.state_dict(),
-                    "main_test_accuracy": model.main_test_accuracy,
+                    "test_accuracy": model.test_accuracy,
+                    "best_test_accuracy": model.best_test_accuracy,
+                    "best_dict": model.best_dict,
                     "train_c_quiz_bags": model.train_c_quiz_bags,
                     "test_c_quiz_bags": model.test_c_quiz_bags,
                 },
-- 
2.39.5