Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 8 Aug 2024 06:21:23 +0000 (08:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 8 Aug 2024 06:21:23 +0000 (08:21 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 8f3568f..86eafea 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -395,8 +395,6 @@ def run_tests(model, quiz_machine, local_device=main_device):
 def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
     nb_train_samples, acc_train_loss = 0, 0.0
 
     hard_w_quizzes = []
@@ -413,7 +411,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         input = input.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
-            optimizer.zero_grad()
+            model.optimizer.zero_grad()
 
         targets = input
 
@@ -435,7 +433,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         loss.backward()
 
         if nb_train_samples % args.batch_size == 0:
-            optimizer.step()
+            model.optimizer.step()
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
 
@@ -470,6 +468,7 @@ c_quizzes_procedure = [
     (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
     (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
     (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
     # (("f_B", "f_A", "A", "B"), (0, 0, 1, 1), model_transformer_cold),
     # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
 ]
@@ -489,22 +488,15 @@ def save_additional_results(model, models, science_w_quizzes):
         recorder=recorder,
     )
 
-    ##
-
-    probas = 0
+    # This is nb_quizzes x nb_models
 
-    for a in range(args.nb_averaging_rounds):
-        # This is nb_quizzes x nb_models
-
-        seq_logproba = quiz_machine.models_logprobas(
-            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        ) + quiz_machine.models_logprobas(
-            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-
-        probas += seq_logproba.exp()
+    seq_logproba = quiz_machine.models_logprobas(
+        models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+    ) + quiz_machine.models_logprobas(
+        models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+    )
 
-    probas /= args.nb_averaging_rounds
+    probas = seq_logproba.exp()
 
     comments = []
 
@@ -597,8 +589,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
 
-    to_recycle = None
-
     while nb_validated_per_model.sum() < nb_to_validate:
         # We use the model that has generated the fewest quizzes to
         # balance the number of quizzes per model overall
@@ -616,35 +606,24 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             nb_to_generate_per_iteration,
             model_for_generation=model,
             procedure=c_quizzes_procedure,
-            to_recycle=to_recycle,
         )
 
         # We discard the trivial ones, according to a criterion
         # specific to the world quizzes (e.g. B=f(B))
 
-        rejected = []
-
         to_keep = quiz_machine.problem.trivial(c_quizzes) == False
 
-        if not to_keep.all():
-            rejected.append(c_quizzes[to_keep == False])
-
         c_quizzes = c_quizzes[to_keep]
 
-        probas = 0
-
-        for a in range(args.nb_averaging_rounds):
-            # This is nb_quizzes x nb_models
-
-            seq_logproba = quiz_machine.models_logprobas(
-                models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-            ) + quiz_machine.models_logprobas(
-                models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-            )
+        # This is nb_quizzes x nb_models
 
-            probas += seq_logproba.exp()
+        seq_logproba = quiz_machine.models_logprobas(
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        ) + quiz_machine.models_logprobas(
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
 
-        probas /= args.nb_averaging_rounds
+        probas = seq_logproba.exp()
 
         nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
         nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
@@ -655,7 +634,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             & (nb_fail <= args.max_fail_to_validate)
         )
 
-        to_recycle = c_quizzes[to_keep == False]
         c_quizzes = c_quizzes[to_keep]
 
         if c_quizzes.size(0) > 0:
@@ -1010,7 +988,6 @@ def train_complexifier(model_gen, model_pred1, model_pred2):
 
 ######################################################################
 
-
 models = []
 
 for k in range(args.nb_gpts):
@@ -1027,9 +1004,11 @@ for k in range(args.nb_gpts):
         dropout=args.dropout,
     ).to(main_device)
 
-    model.main_test_accuracy = 0.0
     model.id = k
 
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.main_test_accuracy = 0.0
+
     model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
         args.nb_train_samples
     )
@@ -1048,8 +1027,9 @@ if args.resume:
 
         try:
             d = torch.load(os.path.join(args.result_dir, filename))
-            model.load_state_dict(d[0])
-            model.main_test_accuracy = d[1]
+            model.load_state_dict(d["state_dict"])
+            model.optimizer.load_state_dict(d["optimizer_state_dict"])
+            model.main_test_accuracy = d["main_test_accuracy"]
             log_string(f"successfully loaded {filename}")
         except FileNotFoundError:
             log_string(f"cannot find {filename}")
@@ -1305,7 +1285,11 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     for model in weakest_models:
         filename = f"gpt_{model.id:03d}.pth"
         torch.save(
-            (model.state_dict(), model.main_test_accuracy),
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "main_test_accuracy": model.main_test_accuracy,
+            },
             os.path.join(args.result_dir, filename),
         )
         log_string(f"wrote {filename}")
index 3fc1066..daa9bbf 100755 (executable)
@@ -81,6 +81,7 @@ class QuizMachine:
         self.answer_len = None
         self.prompt_noise = prompt_noise
 
+        # struct, mask_generate, mask_noise
         self.understood_structures = [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
             (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)),
@@ -178,15 +179,15 @@ class QuizMachine:
         quizzes, from_w = quizzes[i], from_w[i]
 
         self.randomize_configuations_inplace(
-            quizzes, structs=[s for s, m, _ in self.understood_structures]
+            quizzes, structs=[s for s, _, _ in self.understood_structures]
         )
 
         if self.prompt_noise > 0.0:
-            for struct, mask, noise_mask in self.understood_structures:
+            for struct, _, mask_noise in self.understood_structures:
                 i = self.problem.indices_select(quizzes=quizzes, struct=struct)
                 if i.any():
                     quizzes[i] = self.problem.inject_noise(
-                        quizzes[i], self.prompt_noise, struct=struct, mask=noise_mask
+                        quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
                     )
 
         return quizzes, from_w
@@ -228,13 +229,15 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, mask, _ in self.understood_structures:
+        for struct, mask_generate, _ in self.understood_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i] = self.predict(
-                model=model, quizzes=input[i], struct=struct, mask=mask
+                model=model, quizzes=input[i], struct=struct, mask=mask_generate
             )
-            predicted_parts[i] = torch.tensor(mask, device=self.device)[None, :]
+            predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
+                None, :
+            ]
             solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
             correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
 
@@ -329,8 +332,8 @@ class QuizMachine:
         models_for_validation,
         c_quizzes,
         struct,
-        mask,
-        noise_mask=None,
+        mask_value,
+        mask_noise=None,
         device=None,
     ):
         if device is None:
@@ -344,10 +347,10 @@ class QuizMachine:
             device=device,
         )
 
-        if self.prompt_noise > 0.0 and noise_mask is not None:
-            c_quizzes = self.problem.inject_noise(
-                c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
-            )
+        # if self.prompt_noise > 0.0 and mask_noise is not None:
+        # c_quizzes = self.problem.inject_noise(
+        # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise
+        # )
 
         for model in models_for_validation:
             with torch.autograd.no_grad():
@@ -359,7 +362,7 @@ class QuizMachine:
                     seq_logproba.split(self.batch_size),
                 ):
                     input = input.to(device)
-                    ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
+                    ar_mask = self.make_ar_mask(input, struct=struct, mask=mask_value)
                     output = model(mygpt.BracketedSequence(input)).x
                     l[:, model.id] = (
                         -F.cross_entropy(
@@ -374,9 +377,7 @@ class QuizMachine:
 
     ######################################################################
 
-    def generate_c_quizzes(
-        self, nb, model_for_generation, procedure, to_recycle=None, recorder=None
-    ):
+    def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None):
         seq_logproba = torch.zeros(nb, device=self.device)
 
         c_quizzes = None
@@ -408,12 +409,6 @@ class QuizMachine:
                     self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
                 )
 
-            if to_recycle is not None and to_recycle.size(0) > 0:
-                to_recycle = self.problem.reconfigure(to_recycle, s)
-                c_quizzes[: to_recycle.size(0)] = to_recycle
-
-            to_recycle = None
-
         c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
 
         return c_quizzes.to("cpu")