Update.
author: François Fleuret <francois@fleuret.org>
Wed, 17 Jul 2024 03:14:05 +0000 (05:14 +0200)
committer: François Fleuret <francois@fleuret.org>
Wed, 17 Jul 2024 03:14:05 +0000 (05:14 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index ca1e9b5..76db5e2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -82,13 +82,13 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
 parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--generation_temperature", type=float, default=2)
+parser.add_argument("--generation_temperature", type=float, default=2.5)
 
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
@@ -367,9 +367,12 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         acc_train_loss += loss.item() * input.size(0)
 
         loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-        hard_w_quizzes.append(
-            (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
-        )
+        n_forward = input[:, 0] == self.token_forward
+        to_store = from_w & n_forward
+        if to_store.any():
+            hard_w_quizzes.append(
+                (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu"))
+            )
 
         nb_train_samples += input.size(0)
 
@@ -384,11 +387,11 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     run_tests(model, quiz_machine, deterministic_synthesis=False)
 
-    threshold = torch.cat([x[1] for x in hard_w_quizzes], dim=0).sort().values
+    threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
     threshold = threshold[threshold.size(0) // 2]
 
     model.hard_w_quizzes = torch.cat(
-        [x[0][x[1] >= threshold] for x in hard_w_quizzes], dim=0
+        [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
     )
 
     model.to(main_device)
diff --git a/quiz_machine.py b/quiz_machine.py
index 32b3f7e..1168921 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -350,9 +350,7 @@ class QuizMachine:
 
     ######################################################################
 
-    def produce_results(
-        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
-    ):
+    def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis):
         def compute_accuracy(input, log_prefix=None):
             input = input.to(self.device)
             ar_mask = self.make_ar_mask(input)
@@ -400,14 +398,15 @@ class QuizMachine:
 
             return result, correct
 
-        # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
-
         test_result, test_correct = compute_accuracy(
-            model.test_w_quizzes[:nmax], log_prefix="test"
+            model.test_w_quizzes[:2000], log_prefix="test"
         )
 
-        main_test_accuracy = test_correct.sum() / test_correct.size(0)
-        # self.logger(f"main_test_accuracy {n_epoch} model {model.id} {main_test_accuracy}")
+        n_test_forward = model.test_w_quizzes[:, 0] == self.token_forward
+
+        forward_test_correct = test_correct[n_test_forward]
+
+        main_test_accuracy = forward_test_correct.sum() / forward_test_correct.size(0)
 
         ##############################
 
@@ -459,6 +458,9 @@ class QuizMachine:
         else:
             input[...] = self.generate_token_sequences(input.size(0))
 
+        if not forward_only:
+            self.reverse_random_half_in_place(input)
+
     ######################################################################
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):