Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 08:25:49 +0000 (10:25 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 08:25:49 +0000 (10:25 +0200)
grids.py
main.py
quiz_machine.py

index 1a31a36..0564f3b 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -218,7 +218,9 @@ class Grids(problem.Problem):
             dim=1
         ).values
 
-    def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
+    def make_quiz_mask(
+        self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+    ):
         assert self.check_structure(quizzes, struct)
 
         ar_mask = quizzes.new_zeros(quizzes.size())
diff --git a/main.py b/main.py
index f4691cb..a1389c1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -502,18 +502,11 @@ def model_transformer_cold(model):
     # pass
 
 
-warnings.warn("*********** novel procedure!!! **********", RuntimeWarning)
-
 c_quizzes_procedure = [
-    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
-    # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
-    # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
     (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 0), model_transformer_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold),
-    # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
-    # (("f_B", "f_A", "A", "B"), (0, 0, 1, 1), model_transformer_cold),
-    # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
 ]
 
 ######################################################################
@@ -768,7 +761,7 @@ def generate_c_quizzes_with_generator(generator, quiz_machine, nb):
     struct = ("A", "f_A", "B", "f_B")
 
     c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct)
-    ar_mask = quiz_machine.make_ar_mask(c_quizzes, struct, (1, 1, 1, 1))
+    ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1))
 
     i = F.one_hot(
         torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
index 34abd34..ceb523d 100755 (executable)
@@ -82,19 +82,20 @@ class QuizMachine:
         self.prompt_noise = prompt_noise
 
         # struct, mask_generate, mask_noise, mask_loss
-        self.understood_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+        self.train_structures = [
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
             (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+            # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+            # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+            # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
         ]
 
-        self.test_structures = [
-            self.understood_structures[0],
-            self.understood_structures[2],
-            self.understood_structures[4],
-        ]
+        self.test_structures = self.train_structures
 
         self.LOCK_C_QUIZZES = threading.Lock()
         self.train_c_quizzes = []
@@ -185,13 +186,13 @@ class QuizMachine:
         quizzes, from_w = quizzes[i], from_w[i]
 
         self.randomize_configuations_inplace(
-            quizzes, structs=[s for s, _, _, _ in self.understood_structures]
+            quizzes, structs=[s for s, _, _, _ in self.train_structures]
         )
 
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
 
         if self.prompt_noise > 0.0:
-            for struct, _, mask_noise, mask_loss in self.understood_structures:
+            for struct, _, mask_noise, mask_loss in self.train_structures:
                 i = self.problem.indices_select(quizzes=quizzes, struct=struct)
                 if i.any():
                     quizzes[i] = self.problem.inject_noise(
@@ -206,7 +207,7 @@ class QuizMachine:
     ######################################################################
 
     def make_ar_mask(self, quizzes, struct, mask):
-        assert struct in [s for s, _, _, _ in self.understood_structures]
+        assert struct in [s for s, _, _, _ in self.train_structures]
         return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
 
     ######################################################################
@@ -343,7 +344,7 @@ class QuizMachine:
         models_for_validation,
         c_quizzes,
         struct,
-        mask_value,
+        mask_loss,
         mask_noise=None,
         device=None,
     ):
@@ -373,13 +374,15 @@ class QuizMachine:
                     seq_logproba.split(self.batch_size),
                 ):
                     input = input.to(device)
-                    ar_mask = self.make_ar_mask(input, struct=struct, mask=mask_value)
+                    quiz_mask_loss = self.make_ar_mask(
+                        input, struct=struct, mask=mask_loss
+                    )
                     output = model(mygpt.BracketedSequence(input)).x
                     l[:, model.id] = (
                         -F.cross_entropy(
                             output.transpose(1, 2), input, reduction="none"
                         )
-                        * ar_mask
+                        * quiz_mask_loss
                     ).sum(dim=1)
 
                 model.train(t)