From: François Fleuret Date: Sun, 11 Aug 2024 08:30:43 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=7127285e1822ede216e0f967488a4545e7f6f958;p=culture.git Update. --- diff --git a/main.py b/main.py index a1389c1..40772c2 100755 --- a/main.py +++ b/main.py @@ -107,8 +107,6 @@ parser.add_argument("--temperature_cold", type=float, default=1) parser.add_argument("--prompt_noise", type=float, default=0.05) -parser.add_argument("--nb_averaging_rounds", type=int, default=3) - parser.add_argument("--dirty_debug", action="store_true", default=False) parser.add_argument("--test", type=str, default=None) @@ -719,20 +717,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] if vq.size(0) > 0: - probas = 0 - - for a in range(args.nb_averaging_rounds): - # This is nb_quizzes x nb_models - - seq_logproba = quiz_machine.models_logprobas( - models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - - probas += seq_logproba.exp() + seq_logproba = quiz_machine.models_logprobas( + models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) - probas /= args.nb_averaging_rounds + probas = seq_logproba.exp() comments = []