# Command-line flags (diff hunk resolved to its post-image: the
# `--nb_averaging_rounds` option was deleted by the patch and is dropped here).
# Fraction of noise injected into prompts during generation.
parser.add_argument("--prompt_noise", type=float, default=0.05)

# Quick-and-dirty debugging mode (reduced workload / extra dumps, presumably
# — TODO confirm against where args.dirty_debug is read).
parser.add_argument("--dirty_debug", action="store_true", default=False)

# Optional test-mode selector; None means normal run.
parser.add_argument("--test", type=str, default=None)
# Diff hunk resolved to its post-image: the multi-round averaging loop
# (args.nb_averaging_rounds) was deleted by the patch; a single scoring
# pass remains. Indentation restored under the `if` guard.

# Draw a random sample of at most 128 validated quizzes for scoring.
vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]

if vq.size(0) > 0:
    # Log-probability of each quiz under each model, summed over the two
    # presentation orders of the (A, f_A, B, f_B) structure.
    # Shape is presumably nb_quizzes x nb_models (per the removed code's
    # comment) — TODO confirm against quiz_machine.models_logprobas.
    seq_logproba = quiz_machine.models_logprobas(
        models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
    ) + quiz_machine.models_logprobas(
        models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
    )

    # Convert log-probabilities to probabilities in a single pass
    # (no averaging over rounds anymore).
    probas = seq_logproba.exp()

    comments = []