Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 12 Jul 2024 13:33:23 +0000 (15:33 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 12 Jul 2024 13:33:23 +0000 (15:33 +0200)
main.py

diff --git a/main.py b/main.py
index b4ab473..87a67c3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -547,6 +547,27 @@ if args.dirty_debug:
         return l[:, 0] < math.log(0.5)
 
 
         return l[:, 0] < math.log(0.5)
 
 
+######################################################################
+
+nb_loaded_models = 0
+
+for model in models:
+    filename = f"gpt_{model.id:03d}.pth"
+
+    try:
+        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
+        log_string(f"model {model.id} successfully loaded from checkpoint.")
+        nb_loaded_models += 1
+
+    except FileNotFoundError:
+        log_string(f"starting model {model.id} from scratch.")
+
+    except:
+        log_string(f"error when loading {filename}.")
+        exit(1)
+
+assert nb_loaded_models == 0 or nb_loaded_models == len(models)
+
 ######################################################################
 
 for n_epoch in range(args.nb_epochs):
 ######################################################################
 
 for n_epoch in range(args.nb_epochs):
@@ -578,6 +599,10 @@ for n_epoch in range(args.nb_epochs):
     for t in threads:
         t.join()
 
     for t in threads:
         t.join()
 
+    for model in weakest_models:
+        filename = f"gpt_{model.id:03d}.pth"
+        torch.save(model.state_dict(), os.path.join(args.result_dir, filename))
+
     ##################################################
     # Replace a fraction of the w_quizzes with fresh ones
 
     ##################################################
     # Replace a fraction of the w_quizzes with fresh ones