X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=main.py;h=0a7be99f3cbf6dca2131fe512c2b9a3304e0b2d8;hb=240870f5535bac35a08c552108d032854a8e2c38;hp=eb0ef27fb28c334ac7d7f55545f3333930a94238;hpb=fcb71a73da3a27f81383e3000b9ad1ee8da45926;p=culture.git diff --git a/main.py b/main.py index eb0ef27..0a7be99 100755 --- a/main.py +++ b/main.py @@ -417,7 +417,9 @@ def create_c_quizzes( sum_logits += c_quizzes.size(0) * ave_seq_logproba sum_nb_c_quizzes += c_quizzes.size(0) - nb_correct = quizz_machine.compute_correctness(c_quizzes, models) + nb_correct = quizz_machine.compute_correctness( + c_quizzes, models, both_direction=True + ) if args.dirty_debug: nb_correct = torch.randint(