From ee6d81c8669831d80e12fbaa021b7a9c37b54fba Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Fri, 12 Jul 2024 15:33:23 +0200
Subject: [PATCH] Update.

---
 main.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/main.py b/main.py
index b4ab473..87a67c3 100755
--- a/main.py
+++ b/main.py
@@ -547,6 +547,27 @@ if args.dirty_debug:
         return l[:, 0] < math.log(0.5)
 
 
+######################################################################
+
+nb_loaded_models = 0
+
+for model in models:
+    filename = f"gpt_{model.id:03d}.pth"
+
+    try:
+        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
+        log_string(f"model {model.id} successfully loaded from checkpoint.")
+        nb_loaded_models += 1
+
+    except FileNotFoundError:
+        log_string(f"starting model {model.id} from scratch.")
+
+    except:
+        log_string(f"error when loading {filename}.")
+        exit(1)
+
+assert nb_loaded_models == 0 or nb_loaded_models == len(models)
+
 ######################################################################
 
 for n_epoch in range(args.nb_epochs):
@@ -578,6 +599,10 @@ for n_epoch in range(args.nb_epochs):
     for t in threads:
         t.join()
 
+    for model in weakest_models:
+        filename = f"gpt_{model.id:03d}.pth"
+        torch.save(model.state_dict(), os.path.join(args.result_dir, filename))
+
     ##################################################
     # Replace a fraction of the w_quizzes with fresh ones
 
-- 
2.20.1
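
Note for readers of the patch: the two hunks implement an all-or-nothing checkpoint scheme. At startup, every model tries to resume from gpt_{id:03d}.pth in the result directory; a missing file means that model starts from scratch, and the final assert enforces that either all models resumed or none did, since a partial resume would silently mix trained and untrained models. At the end of each epoch, the freshly trained weakest models are written back to the same files. Below is a minimal, self-contained sketch of the startup half of this pattern; Net, result_dir, and the .id attribute are hypothetical stand-ins for the script's models, args.result_dir, and per-model identifiers, and the sketch reports the underlying exception rather than using the patch's bare except.

    # Minimal sketch of the patch's all-or-nothing checkpoint-resume pattern.
    # Net, result_dir, and .id are hypothetical stand-ins for main.py's
    # models, args.result_dir, and per-model identifiers.
    import os
    import sys

    import torch
    from torch import nn


    class Net(nn.Module):  # hypothetical tiny model, for illustration only
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)


    result_dir = "results"  # assumed checkpoint directory
    models = [Net() for _ in range(2)]
    for i, m in enumerate(models):
        m.id = i  # the patch assumes every model carries an .id attribute

    nb_loaded_models = 0

    for model in models:
        filename = os.path.join(result_dir, f"gpt_{model.id:03d}.pth")
        try:
            model.load_state_dict(torch.load(filename))
            nb_loaded_models += 1
        except FileNotFoundError:
            pass  # no checkpoint yet: this model starts from scratch
        except Exception as e:  # unlike the patch's bare except, report the cause
            sys.exit(f"error when loading {filename}: {e}")

    # Either every model resumed from disk or none did; a partial resume
    # would leave the population in an inconsistent state.
    assert nb_loaded_models in (0, len(models))

The saving half mirrors this: torch.save(model.state_dict(), path) for each model that was just retrained, so the next run of the script resumes from exactly these files.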