Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 16:23:52 +0000 (18:23 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 16:23:52 +0000 (18:23 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 35ba763..fc480b7 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -846,7 +846,13 @@ def test_ae(local_device=main_device):
         model.train()
         nb_train_samples, acc_train_loss = 0, 0.0
 
-        full_input, full_mask_loss = quiz_machine.data_input(args.nb_train_samples)
+        data_structures = [
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+        ]
+
+        full_input, full_mask_loss = quiz_machine.data_input(
+            args.nb_train_samples, data_structures=data_structures
+        )
 
         src = zip(
             full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
@@ -866,6 +872,7 @@ def test_ae(local_device=main_device):
 
             targets = input
             input = (mask_loss == 0).long() * input
+
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
             acc_train_loss += loss.item() * input.size(0)
index ceb527a..08f121a 100755 (executable)
@@ -140,7 +140,12 @@ class QuizMachine:
 
     ######################################################################
 
-    def data_input(self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1):
+    def data_input(
+        self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1, data_structures=None
+    ):
+        if data_structures is None:
+            data_structures = self.train_structures
+
         if len(c_quiz_bags) > 0:
             c_quizzes = torch.cat(c_quiz_bags, dim=0)
 
@@ -170,21 +175,24 @@ class QuizMachine:
         quizzes = quizzes[i]
 
         self.randomize_configuations_inplace(
-            quizzes, structs=[s for s, _, _, _ in self.train_structures]
+            quizzes, structs=[s for s, _, _, _ in data_structures]
         )
 
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
 
-        if self.prompt_noise > 0.0:
-            for struct, _, quad_noise, quad_loss in self.train_structures:
-                i = self.problem.indices_select(quizzes=quizzes, struct=struct)
-                if i.any():
+        for struct, _, quad_noise, quad_loss in data_structures:
+            i = self.problem.indices_select(quizzes=quizzes, struct=struct)
+            if i.any():
+                if self.prompt_noise > 0.0:
                     quizzes[i] = self.problem.inject_noise(
                         quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
                     )
-                    quiz_mask_loss[i] = self.make_quiz_mask(
-                        quizzes=quizzes[i], struct=struct, quad=quad_loss
-                    )
+                quiz_mask_loss[i] = self.make_quiz_mask(
+                    quizzes=quizzes[i], struct=struct, quad=quad_loss
+                )
+
+        print("quad_loss", quad_loss)
+        print("quiz_mask_loss", quiz_mask_loss)
 
         return quizzes, quiz_mask_loss