From 2f87c91cf606a068de1450d198660de7e44cd356 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 12 Jul 2024 17:36:30 +0200 Subject: [PATCH] Update. --- main.py | 58 +++++++++++++++++++++++++++++++++---------------- quiz_machine.py | 6 +++++ 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 8715711..a8ceac8 100755 --- a/main.py +++ b/main.py @@ -32,6 +32,8 @@ parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) +parser.add_argument("--resume", action="store_true", default=False) + parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1) ######################################## @@ -190,11 +192,15 @@ else: ###################################################################### -try: - os.mkdir(args.result_dir) -except FileExistsError: - print(f"result directory {args.result_dir} already exists") - exit(1) +if args.resume: + assert os.path.isdir(args.result_dir) + +else: + try: + os.mkdir(args.result_dir) + except FileExistsError: + print(f"result directory {args.result_dir} already exists") + exit(1) log_file = open(os.path.join(args.result_dir, args.log_filename), "a") @@ -437,8 +443,7 @@ def create_c_quizzes( quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) - # save a bunch of images to investigate what quizzes with a - # certain nb of correct predictions look like + # save images q = new_c_quizzes[:72] @@ -450,8 +455,6 @@ def create_c_quizzes( ###################################################################### -nb_loaded_models = 0 - models = [] for k in range(args.nb_gpts): @@ -475,23 +478,37 @@ for k in range(args.nb_gpts): model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples) quiz_machine.reverse_random_half_in_place(model.test_w_quizzes) - filename = f"gpt_{model.id:03d}.pth" + models.append(model) + +###################################################################### +if args.resume: try: - model.load_state_dict(torch.load(os.path.join(args.result_dir, filename))) - log_string(f"model {model.id} successfully loaded from checkpoint.") - nb_loaded_models += 1 - - except FileNotFoundError: - log_string(f"starting model {model.id} from scratch.") + for model in models: + filename = f"gpt_{model.id:03d}.pth" + + try: + model.load_state_dict( + torch.load(os.path.join(args.result_dir, filename)) + ) + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + + try: + filename = "c_quizzes.pth" + quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename)) + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass except: log_string(f"error when loading {filename}.") exit(1) - models.append(model) - -assert nb_loaded_models == 0 or nb_loaded_models == len(models) +###################################################################### nb_parameters = sum(p.numel() for p in models[0].parameters()) log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") @@ -600,6 +617,7 @@ for n_epoch in range(args.nb_epochs): for model in weakest_models: filename = f"gpt_{model.id:03d}.pth" torch.save(model.state_dict(), os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") ################################################## # Replace a fraction of the w_quizzes with fresh ones @@ -625,4 +643,6 @@ for n_epoch in range(args.nb_epochs): nb_for_test=nb_new_c_quizzes_for_test, ) + quiz_machine.save_c_quizzes(os.path.join(args.result_dir, "c_quizzes.pth")) + ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index 88fd9f1..c39bf7a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -412,6 +412,12 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes.to("cpu")) + def save_c_quizzes(self, filename): + torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) + + def load_c_quizzes(self, filename): + self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) + ###################################################################### def logproba_of_solutions(self, models, c_quizzes): -- 2.39.5