From 7b716a85786247b292ee71a635c98a18c66b421d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 12 Jul 2024 15:58:04 +0200 Subject: [PATCH] Update. --- grids.py | 10 ++++++---- main.py | 42 ++++++++++++++++++++---------------------- quiz_machine.py | 6 +++--- sky.py | 4 ++-- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/grids.py b/grids.py index 5dad6f3..002a33f 100755 --- a/grids.py +++ b/grids.py @@ -996,7 +996,7 @@ class Grids(problem.Problem): return prompts.flatten(1), answers.flatten(1) - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -1021,7 +1021,7 @@ class Grids(problem.Problem): for t in self.all_tasks: print(t.__name__) prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t]) - self.save_quizzes( + self.save_quiz_illustrations( result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow ) @@ -1056,7 +1056,9 @@ if __name__ == "__main__": ]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) - grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow) + grids.save_quiz_illustrations( + "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow + ) exit(0) @@ -1075,7 +1077,7 @@ if __name__ == "__main__": predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - grids.save_quizzes( + grids.save_quiz_illustrations( "/tmp", "test", prompts[:nb], diff --git a/main.py b/main.py index 87a67c3..8715711 100755 --- a/main.py +++ b/main.py @@ -443,11 +443,15 @@ def create_c_quizzes( q = new_c_quizzes[:72] if q.size(0) > 0: - quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q) + quiz_machine.save_quiz_illustrations( + args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q + ) ###################################################################### +nb_loaded_models = 0 + models = [] for k in range(args.nb_gpts): @@ -471,8 +475,23 @@ for k in range(args.nb_gpts): model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples) quiz_machine.reverse_random_half_in_place(model.test_w_quizzes) + filename = f"gpt_{model.id:03d}.pth" + + try: + model.load_state_dict(torch.load(os.path.join(args.result_dir, filename))) + log_string(f"model {model.id} successfully loaded from checkpoint.") + nb_loaded_models += 1 + + except FileNotFoundError: + log_string(f"starting model {model.id} from scratch.") + + except: + log_string(f"error when loading {filename}.") + exit(1) + models.append(model) +assert nb_loaded_models == 0 or nb_loaded_models == len(models) nb_parameters = sum(p.numel() for p in models[0].parameters()) log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") @@ -547,27 +566,6 @@ if args.dirty_debug: return l[:, 0] < math.log(0.5) -###################################################################### - -nb_loaded_models = 0 - -for model in models: - filename = f"gpt_{model.id:03d}.pth" - - try: - model.load_state_dict(torch.load(os.path.join(args.result_dir, filename))) - log_string(f"model {model.id} successfully loaded from checkpoint.") - nb_loaded_models += 1 - - except FileNotFoundError: - log_string(f"starting model {model.id} from scratch.") - - except: - log_string(f"error when loading {filename}.") - exit(1) - -assert nb_loaded_models == 0 or nb_loaded_models == len(models) - ###################################################################### for n_epoch in range(args.nb_epochs): diff --git a/quiz_machine.py b/quiz_machine.py index 631d41b..88fd9f1 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -241,7 +241,7 @@ class QuizMachine: self.train_c_quizzes = [] self.test_c_quizzes = [] - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -266,7 +266,7 @@ class QuizMachine: predicted_prompts *= 2 predicted_answers *= 2 - self.problem.save_quizzes( + self.problem.save_quiz_illustrations( result_dir, filename_prefix, quizzes[:, 1 : 1 + self.prompt_len], @@ -384,7 +384,7 @@ class QuizMachine: ############################## - self.save_quizzes( + self.save_quiz_illustrations( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=test_result[:72], diff --git a/sky.py b/sky.py index 1768a81..cc5bd4f 100755 --- a/sky.py +++ b/sky.py @@ -300,7 +300,7 @@ class Sky(problem.Problem): return prompts, answers - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -331,7 +331,7 @@ if __name__ == "__main__": predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 - sky.save_quizzes( + sky.save_quiz_illustrations( "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers ) -- 2.39.5