From: François Fleuret
Date: Fri, 12 Jul 2024 13:33:23 +0000 (+0200)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ee6d81c8669831d80e12fbaa021b7a9c37b54fba;p=culture.git

Update.
---

Notes (this region is ignored by patch tools):

What the patch does
  * Hunk 1: before the epoch loop, tries to load one checkpoint per model
    from args.result_dir (file name gpt_<id>.pth). A missing file is
    logged and that model keeps its current weights; any other failure is
    logged and the program exits with status 1. Afterwards it checks that
    either no checkpoint or every checkpoint was loaded.
  * Hunk 2: inside the epoch loop, after the threads have been joined,
    writes the state_dict of every model in weakest_models to the same
    file names.

Review changes relative to the recorded commit
  * `except:` -> `except Exception as e:` so that KeyboardInterrupt and
    SystemExit are no longer caught, and the exception is included in the
    log message instead of being discarded.
  * `exit(1)` -> `raise SystemExit(1)`: same effect, but does not depend
    on the `exit` helper installed by the `site` module.
  * Explanatory comments added to the new code; hunk line counts and the
    second hunk's new-file offset were adjusted accordingly.
  * The post-image blob id on the `index` line is the one from the
    recorded commit and no longer matches the modified content.

NOTE(review): the patch was recovered from a copy whose line breaks and
leading whitespace had been collapsed. The indentation of the context
lines and the position of blank context lines are reconstructed from the
hunk counts -- confirm against main.py before applying.

diff --git a/main.py b/main.py
index b4ab473..87a67c3 100755
--- a/main.py
+++ b/main.py
@@ -547,6 +547,32 @@ if args.dirty_debug:
         return l[:, 0] < math.log(0.5)
 
 
+######################################################################
+
+# Resume from the per-model checkpoints in args.result_dir, if present.
+
+nb_loaded_models = 0
+
+for model in models:
+    filename = f"gpt_{model.id:03d}.pth"
+
+    try:
+        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
+        log_string(f"model {model.id} successfully loaded from checkpoint.")
+        nb_loaded_models += 1
+
+    except FileNotFoundError:
+        log_string(f"starting model {model.id} from scratch.")
+
+    except Exception as e:
+        # Report why the load failed; KeyboardInterrupt and SystemExit
+        # are deliberately not caught here.
+        log_string(f"error when loading {filename}: {e!r}")
+        raise SystemExit(1)
+
+# Either every model was restored or none was; a partial set is rejected.
+assert nb_loaded_models == 0 or nb_loaded_models == len(models)
+
 ######################################################################
 
 for n_epoch in range(args.nb_epochs):
@@ -578,6 +604,11 @@ for n_epoch in range(args.nb_epochs):
     for t in threads:
         t.join()
 
+    # Write the current weights of each model in weakest_models to disk.
+    for model in weakest_models:
+        filename = f"gpt_{model.id:03d}.pth"
+        torch.save(model.state_dict(), os.path.join(args.result_dir, filename))
+
     ##################################################
 
     # Replace a fraction of the w_quizzes with fresh ones