Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 06:08:28 +0000 (08:08 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jul 2024 06:08:28 +0000 (08:08 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 5c58beb..b149e62 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -88,10 +88,12 @@ parser.add_argument("--proba_understands", type=float, default=0.9)
 
 parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--generation_temperature", type=float, default=2)
+parser.add_argument("--generation_temperature", type=float, default=2)
 
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
+parser.add_argument("--forward_only", action="store_true", default=False)
+
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 ######################################################################
@@ -411,10 +413,10 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     while nb_validated < nb_to_create:
         model_for_generation = models[torch.randint(len(models), (1,))]
 
-        c_quizzes = quiz_machine.generate_quizzes(
+        c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
             model_for_generation=model_for_generation,
-            temperature=args.generation_temperature,
+            forward_only=args.forward_only,
         )
 
         c_quizzes = keep_good_quizzes(models, c_quizzes)
@@ -482,10 +484,19 @@ for k in range(args.nb_gpts):
     model.main_test_accuracy = 0.0
     model.id = k
 
-    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
-    quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
-    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
-    quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+    quiz_machine.create_w_quizzes(
+        model=model,
+        nb=args.nb_train_samples,
+        for_train=True,
+        forward_only=args.forward_only,
+    )
+
+    quiz_machine.create_w_quizzes(
+        model=model,
+        nb=args.nb_test_samples,
+        for_train=False,
+        forward_only=args.forward_only,
+    )
 
     models.append(model)
 
@@ -659,7 +670,11 @@ for n_epoch in range(args.nb_epochs):
     # Renew the training samples
 
     for model in weakest_models:
-        quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
-
+        quiz_machine.renew_w_quizzes(
+            model=model,
+            nb=args.nb_train_samples,
+            for_train=True,
+            forward_only=args.forward_only,
+        )
 
 ######################################################################
index 0f834dc..008e435 100755 (executable)
@@ -428,12 +428,26 @@ class QuizMachine:
 
     ######################################################################
 
-    def renew_w_quizzes(self, model, nb, for_train=True):
+    def create_w_quizzes(self, model, nb, for_train=True, forward_only=False):
+        input = self.generate_token_sequences(nb)
+
+        if not forward_only:
+            self.reverse_random_half_in_place(input)
+
+        if for_train:
+            model.train_w_quizzes = input
+        else:
+            model.test_w_quizzes = input
+
+    ######################################################################
+
+    def renew_w_quizzes(self, model, nb, for_train=True, forward_only=False):
         input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
-        self.reverse_random_half_in_place(fresh_w_quizzes)
+        if not forward_only:
+            self.reverse_random_half_in_place(fresh_w_quizzes)
         input[-nb:] = fresh_w_quizzes.to("cpu")
 
     ######################################################################
@@ -527,7 +541,7 @@ class QuizMachine:
 
     ###############################################################
 
-    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
+    def generate_c_quizzes(self, nb, model_for_generation, forward_only=False):
         c_quizzes = torch.empty(
             nb,
             self.prompt_len + self.answer_len + 2,
@@ -537,29 +551,69 @@ class QuizMachine:
 
         seq_logproba = torch.zeros(nb, device=self.device)
 
-        c_quizzes[:, 0] = self.token_forward
-        c_quizzes[:, 1 + self.prompt_len] = self.token_forward
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes, first=True),
-            seq_logproba=seq_logproba,
-            temperature=1.0,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
+        if forward_only:
+            c_quizzes[:, 0] = self.token_forward
+            c_quizzes[:, 1 + self.prompt_len] = self.token_forward
 
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes),
-            seq_logproba=seq_logproba,
-            temperature=1,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes, first=True),
+                seq_logproba=seq_logproba,
+                temperature=1.0,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes),
+                seq_logproba=seq_logproba,
+                temperature=1,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+        else:
+            c_quizzes[:, 0] = self.token_backward
+            c_quizzes[:, 1 + self.answer_len] = self.token_backward
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes, first=True),
+                seq_logproba=seq_logproba,
+                temperature=1.0,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes),
+                seq_logproba=seq_logproba,
+                temperature=1,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
+
+            c_quizzes = self.reverse_time(c_quizzes)
+
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=self.make_ar_mask(c_quizzes),
+                seq_logproba=seq_logproba,
+                temperature=1,
+                deterministic_synthesis=False,
+                device=self.device,
+            )
 
         return c_quizzes.to("cpu")