Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 17:36:19 +0000 (19:36 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 17:36:19 +0000 (19:36 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 526da6f..94c030a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -581,9 +581,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         # This is nb_quizzes x nb_models
 
         seq_logproba = quiz_machine.models_logprobas(
-            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
         ) + quiz_machine.models_logprobas(
-            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
         )
 
         probas = seq_logproba.exp()
@@ -648,9 +648,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     if vq.size(0) > 0:
         seq_logproba = quiz_machine.models_logprobas(
-            models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+            models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
         ) + quiz_machine.models_logprobas(
-            models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+            models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
         )
 
         comments = []
@@ -753,9 +753,17 @@ def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.
 
             if c_quizzes.size(0) > 0:
                 seq_logproba = quiz_machine.models_logprobas(
-                    models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+                    models,
+                    c_quizzes,
+                    ("A", "f_A", "B", "f_B"),
+                    (0, 0, 0, 1),
+                    (0, 0, 1, 0),
                 ) + quiz_machine.models_logprobas(
-                    models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+                    models,
+                    c_quizzes,
+                    ("f_A", "A", "f_B", "B"),
+                    (0, 0, 0, 1),
+                    (0, 0, 1, 0),
                 )
 
                 probas = seq_logproba.exp()
@@ -1075,9 +1083,9 @@ if args.test_generator:
         )
 
         seq_logproba = quiz_machine.models_logprobas(
-            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
         ) + quiz_machine.models_logprobas(
-            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
         )
 
         probas = seq_logproba.exp()
index b7c3b09..0fdfbf6 100755 (executable)
@@ -325,13 +325,24 @@ class QuizMachine:
     ######################################################################
 
     def models_logprobas(
-        self, models_for_validation, c_quizzes, struct, mask, device=None
+        self,
+        models_for_validation,
+        c_quizzes,
+        struct,
+        mask,
+        noise_mask=None,
+        device=None,
     ):
         if device is None:
             device = self.device
 
         c_quizzes = self.problem.reconfigure(c_quizzes, struct)
 
+        if self.prompt_noise > 0.0 and noise_mask is not None:
+            c_quizzes = self.problem.inject_noise(
+                c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
+            )
+
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,