From 7c2a7ea9870de944538443423d8cf60cc6d74d4e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 22:56:15 +0300 Subject: [PATCH] Update. --- main.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/main.py b/main.py index 00a9492..4ff50d7 100755 --- a/main.py +++ b/main.py @@ -423,23 +423,24 @@ def create_c_quizzes( c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] - nb_correct, seq_logproba = quiz_machine.compute_correctness( - c_quizzes, - models, - bidirectional_validation=args.bidirectional_validation, - deterministic_validation=args.deterministic_validation, - ) + if c_quizzes.size(0) > 0: + nb_correct, seq_logproba = quiz_machine.compute_correctness( + c_quizzes, + models, + bidirectional_validation=args.bidirectional_validation, + deterministic_validation=args.deterministic_validation, + ) - for n, l in zip(nb_correct, seq_logproba): - s = " ".join([str(x.item()) for x in l]) - logp_file.write(f"{n} {s}\n") + for n, l in zip(nb_correct, seq_logproba): + s = " ".join([str(x.item()) for x in l]) + logp_file.write(f"{n} {s}\n") - if args.dirty_debug: - nb_correct = torch.randint( - len(models) + 1, nb_correct.size(), device=c_quizzes.device - ) + if args.dirty_debug: + nb_correct = torch.randint( + len(models) + 1, nb_correct.size(), device=c_quizzes.device + ) - quizzes_and_nb_correct_records.append((c_quizzes, nb_correct)) + quizzes_and_nb_correct_records.append((c_quizzes, nb_correct)) nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) nv = " ".join([str(x.item()) for x in nv]) -- 2.20.1