sum_logits += c_quizzes.size(0) * ave_seq_logproba
sum_nb_c_quizzes += c_quizzes.size(0)
- nb_correct = quizz_machine.compute_correctness(c_quizzes, models)
+ nb_correct = quizz_machine.compute_correctness(
+ c_quizzes, models, both_direction=True
+ )
if args.dirty_debug:
nb_correct = torch.randint(