Update.
author: François Fleuret <francois@fleuret.org>
Mon, 29 Jul 2024 07:49:25 +0000 (09:49 +0200)
committer: François Fleuret <francois@fleuret.org>
Mon, 29 Jul 2024 07:49:25 +0000 (09:49 +0200)
quiz_machine.py

index 34f6b62..ca71c95 100755 (executable)
@@ -444,40 +444,45 @@ class QuizMachine:
 
     ###############################################################
 
-    def optimize_quizzes(self, quiz, nb_variants, nb_iterations, struct, mask):
+    def optimize_quizzes(
+        self, models, quiz, nb_variants, nb_iterations, struct, mask, proba_understands
+    ):
         for _ in range(nb_iterations):
-            candidates = quizzes[None].expand(nb_variants, -1)
+            candidates = quiz[None, :].expand(nb_variants, -1).clone()
             r = torch.rand(candidates.size(), device=candidates.device)
             u = r.reshape(r.size(0), 4, candidates.size(1) // 4)
             # Only change the part indicated by the mask and do not
             # touch the special tokens
             u[:, :, 0] = 0
-            u = u * torch.tensor(mask, device=u.device)[None, :, None]
-            random_mask = (r.sort(dim=0, descending=True).indices == 0).long()
+            u = u * (1 - torch.tensor(mask, device=u.device)[None, :, None])
+            random_mask = F.one_hot(r.argmax(dim=1), num_classes=r.size(1))
             # Keep the first unchanged
-            random_mask[:, 0, :] = 0
+            random_mask[0, :] = 0
             # Reshape without the 4 parts
             candidates.reshape(-1, candidates.size(-1))
             random_mask.reshape(candidates.size())
             random_tokens = torch.randint(
-                self.problem.nb_token_values - 4, random_mask.size()
+                self.problem.nb_token_values - 4,
+                random_mask.size(),
+                device=candidates.device,
             )
             # Apply the noise
             candidates = (1 - random_mask) * candidates + random_mask * random_tokens
-            seq_logproba = quiz_machine.models_logprobas(
+            seq_logproba = self.models_logprobas(
                 models, candidates, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-            ) + quiz_machine.models_logprobas(
+            ) + self.models_logprobas(
                 models, candidates, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
             )
             sorted_logprobas = seq_logproba.sort(dim=1).values.exp()
             lowest, second_lowest = sorted_logprobas[:, 0], sorted_logprobas[:, 1]
             score = second_lowest - lowest
 
-            score = score * (second_lowest > args.proba_understands)
+            # score = score * (second_lowest > proba_understands)
 
             quiz = candidates[score.argmax()]
+            print(score.max())
 
-        return quiz
+        return quiz.to("cpu")
 
     def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
         seq_logproba = torch.zeros(nb, device=self.device)