From 7127285e1822ede216e0f967488a4545e7f6f958 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 11 Aug 2024 10:30:43 +0200 Subject: [PATCH] Update. --- main.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index a1389c1..40772c2 100755 --- a/main.py +++ b/main.py @@ -107,8 +107,6 @@ parser.add_argument("--temperature_cold", type=float, default=1) parser.add_argument("--prompt_noise", type=float, default=0.05) -parser.add_argument("--nb_averaging_rounds", type=int, default=3) - parser.add_argument("--dirty_debug", action="store_true", default=False) parser.add_argument("--test", type=str, default=None) @@ -719,20 +717,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] if vq.size(0) > 0: - probas = 0 - - for a in range(args.nb_averaging_rounds): - # This is nb_quizzes x nb_models - - seq_logproba = quiz_machine.models_logprobas( - models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - - probas += seq_logproba.exp() + seq_logproba = quiz_machine.models_logprobas( + models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) - probas /= args.nb_averaging_rounds + probas = seq_logproba.exp() comments = [] -- 2.39.5