Update.
author	François Fleuret <francois@fleuret.org>
Fri, 12 Jul 2024 13:58:04 +0000 (15:58 +0200)
committer	François Fleuret <francois@fleuret.org>
Fri, 12 Jul 2024 13:58:04 +0000 (15:58 +0200)
grids.py
main.py
quiz_machine.py
sky.py

diff --git a/grids.py b/grids.py
index 5dad6f3..002a33f 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -996,7 +996,7 @@ class Grids(problem.Problem):
 
         return prompts.flatten(1), answers.flatten(1)
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -1021,7 +1021,7 @@ class Grids(problem.Problem):
         for t in self.all_tasks:
             print(t.__name__)
             prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
-            self.save_quizzes(
+            self.save_quiz_illustrations(
                 result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
             )
 
@@ -1056,7 +1056,9 @@ if __name__ == "__main__":
     ]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-        grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
+        grids.save_quiz_illustrations(
+            "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+        )
 
     exit(0)
 
@@ -1075,7 +1077,7 @@ if __name__ == "__main__":
     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
-    grids.save_quizzes(
+    grids.save_quiz_illustrations(
         "/tmp",
         "test",
         prompts[:nb],
diff --git a/main.py b/main.py
index 87a67c3..8715711 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -443,11 +443,15 @@ def create_c_quizzes(
     q = new_c_quizzes[:72]
 
     if q.size(0) > 0:
-        quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q)
+        quiz_machine.save_quiz_illustrations(
+            args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q
+        )
 
 
 ######################################################################
 
+nb_loaded_models = 0
+
 models = []
 
 for k in range(args.nb_gpts):
@@ -471,8 +475,23 @@ for k in range(args.nb_gpts):
     model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
     quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
 
+    filename = f"gpt_{model.id:03d}.pth"
+
+    try:
+        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
+        log_string(f"model {model.id} successfully loaded from checkpoint.")
+        nb_loaded_models += 1
+
+    except FileNotFoundError:
+        log_string(f"starting model {model.id} from scratch.")
+
+    except:
+        log_string(f"error when loading {filename}.")
+        exit(1)
+
     models.append(model)
 
+assert nb_loaded_models == 0 or nb_loaded_models == len(models)
 
 nb_parameters = sum(p.numel() for p in models[0].parameters())
 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
@@ -547,27 +566,6 @@ if args.dirty_debug:
         return l[:, 0] < math.log(0.5)
 
 
-######################################################################
-
-nb_loaded_models = 0
-
-for model in models:
-    filename = f"gpt_{model.id:03d}.pth"
-
-    try:
-        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
-        log_string(f"model {model.id} successfully loaded from checkpoint.")
-        nb_loaded_models += 1
-
-    except FileNotFoundError:
-        log_string(f"starting model {model.id} from scratch.")
-
-    except:
-        log_string(f"error when loading {filename}.")
-        exit(1)
-
-assert nb_loaded_models == 0 or nb_loaded_models == len(models)
-
 ######################################################################
 
 for n_epoch in range(args.nb_epochs):
diff --git a/quiz_machine.py b/quiz_machine.py
index 631d41b..88fd9f1 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -241,7 +241,7 @@ class QuizMachine:
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -266,7 +266,7 @@ class QuizMachine:
             predicted_prompts *= 2
             predicted_answers *= 2
 
-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
@@ -384,7 +384,7 @@ class QuizMachine:
 
         ##############################
 
-        self.save_quizzes(
+        self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
diff --git a/sky.py b/sky.py
index 1768a81..cc5bd4f 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -300,7 +300,7 @@ class Sky(problem.Problem):
 
         return prompts, answers
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -331,7 +331,7 @@ if __name__ == "__main__":
     predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
     predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
 
-    sky.save_quizzes(
+    sky.save_quiz_illustrations(
         "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
     )