From ea959eb80f58a53c81f4e57aa2a0cf713498c7fa Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 13 Aug 2024 14:26:55 +0200 Subject: [PATCH] Update. --- main.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index bd46948..dda62af 100755 --- a/main.py +++ b/main.py @@ -165,7 +165,7 @@ assert not args.grids_science_tasks or ( default_args = { "model": "37M", "batch_size": 25, - "inference_batch_size": 50, + "inference_batch_size": 25, "nb_train_samples": 40000, "nb_test_samples": 1000, } @@ -806,19 +806,13 @@ if args.resume: model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.main_test_accuracy = d["main_test_accuracy"] + model.train_c_quiz_bags = d["train_c_quiz_bags"] + model.test_c_quiz_bags = d["test_c_quiz_bags"] log_string(f"successfully loaded {filename}") except FileNotFoundError: log_string(f"cannot find {filename}") pass - try: - filename = "c_quizzes.pth" - quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename)) - log_string(f"successfully loaded {filename}") - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass - try: filename = "state.pth" state = torch.load(os.path.join(args.result_dir, filename)) @@ -878,10 +872,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): args.nb_new_c_quizzes_for_test, ) - filename = "c_quizzes.pth" - quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename)) - log_string(f"wrote {filename}") - # Force one epoch of training for model in models: model.main_test_accuracy = 0.0 @@ -918,6 +908,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): "state_dict": model.state_dict(), "optimizer_state_dict": model.optimizer.state_dict(), "main_test_accuracy": model.main_test_accuracy, + "train_c_quiz_bags": model.train_c_quiz_bags, + "test_c_quiz_bags": model.test_c_quiz_bags, }, os.path.join(args.result_dir, filename), ) -- 2.39.5