Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 10 Aug 2024 22:58:46 +0000 (00:58 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 10 Aug 2024 22:58:46 +0000 (00:58 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 0670262..f4691cb 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -390,13 +390,26 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
-        full_input, _ = quiz_machine.data_input(model, split="test")
-        src = full_input.split(args.batch_size)
+        full_input, full_mask_loss = quiz_machine.data_input(model, split="test")
+        src = zip(
+            full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+        )
 
-        for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
+        for input, mask_loss in tqdm.tqdm(
+            src,
+            dynamic_ncols=True,
+            desc="test",
+            total=full_input.size(0) // args.batch_size,
+        ):
             input = input.to(local_device)
+            mask_loss = mask_loss.to(local_device)
+            targets = input
+
             output = model(mygpt.BracketedSequence(input)).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
+            loss_per_token = F.cross_entropy(
+                output.transpose(1, 2), targets, reduction="none"
+            )
+            loss = (loss_per_token * mask_loss).mean()
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
 
@@ -426,16 +439,17 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     hard_w_quizzes = []
 
-    full_input, full_from_w = quiz_machine.data_input(model, split="train")
-    src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size))
+    full_input, full_mask_loss = quiz_machine.data_input(model, split="train")
+    src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
 
-    for input, from_w in tqdm.tqdm(
+    for input, mask_loss in tqdm.tqdm(
         src,
         dynamic_ncols=True,
         desc="training",
         total=full_input.size(0) // args.batch_size,
     ):
         input = input.to(local_device)
+        mask_loss = mask_loss.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
@@ -446,14 +460,10 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"
         )
-        loss = loss_per_token.mean() + model.loss
+        loss = (loss_per_token * mask_loss).mean() + model.loss
         acc_train_loss += loss.item() * input.size(0)
 
         loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-        if from_w.any():
-            hard_w_quizzes.append(
-                (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu"))
-            )
 
         nb_train_samples += input.size(0)
 
index daa9bbf..34abd34 100755 (executable)
@@ -81,13 +81,19 @@ class QuizMachine:
         self.answer_len = None
         self.prompt_noise = prompt_noise
 
-        # struct, mask_generate, mask_noise
+        # struct, mask_generate, mask_noise, mask_loss
         self.understood_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0)),
-            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0)),
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+        ]
+
+        self.test_structures = [
+            self.understood_structures[0],
+            self.understood_structures[2],
+            self.understood_structures[4],
         ]
 
         self.LOCK_C_QUIZZES = threading.Lock()
@@ -179,23 +185,28 @@ class QuizMachine:
         quizzes, from_w = quizzes[i], from_w[i]
 
         self.randomize_configuations_inplace(
-            quizzes, structs=[s for s, _, _ in self.understood_structures]
+            quizzes, structs=[s for s, _, _, _ in self.understood_structures]
         )
 
+        quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
+
         if self.prompt_noise > 0.0:
-            for struct, _, mask_noise in self.understood_structures:
+            for struct, _, mask_noise, mask_loss in self.understood_structures:
                 i = self.problem.indices_select(quizzes=quizzes, struct=struct)
                 if i.any():
                     quizzes[i] = self.problem.inject_noise(
                         quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
                     )
+                    quiz_mask_loss[i] = self.make_ar_mask(
+                        quizzes=quizzes[i], struct=struct, mask=mask_loss
+                    )
 
-        return quizzes, from_w
+        return quizzes, quiz_mask_loss
 
     ######################################################################
 
     def make_ar_mask(self, quizzes, struct, mask):
-        assert struct in [s for s, _, _ in self.understood_structures]
+        assert struct in [s for s, _, _, _ in self.understood_structures]
         return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
 
     ######################################################################
@@ -229,7 +240,7 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, mask_generate, _ in self.understood_structures:
+        for struct, mask_generate, _, _ in self.test_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i] = self.predict(