Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 08:30:43 +0000 (10:30 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 08:30:43 +0000 (10:30 +0200)
main.py

diff --git a/main.py b/main.py
index a1389c1..40772c2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -107,8 +107,6 @@ parser.add_argument("--temperature_cold", type=float, default=1)
 
 parser.add_argument("--prompt_noise", type=float, default=0.05)
 
-parser.add_argument("--nb_averaging_rounds", type=int, default=3)
-
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test", type=str, default=None)
@@ -719,20 +717,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
 
     if vq.size(0) > 0:
-        probas = 0
-
-        for a in range(args.nb_averaging_rounds):
-            # This is nb_quizzes x nb_models
-
-            seq_logproba = quiz_machine.models_logprobas(
-                models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-            ) + quiz_machine.models_logprobas(
-                models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-            )
-
-            probas += seq_logproba.exp()
+        seq_logproba = quiz_machine.models_logprobas(
+            models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        ) + quiz_machine.models_logprobas(
+            models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
 
-        probas /= args.nb_averaging_rounds
+        probas = seq_logproba.exp()
 
         comments = []