From: François Fleuret Date: Sat, 6 Jul 2024 19:56:15 +0000 (+0300) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;ds=inline;h=7c2a7ea9870de944538443423d8cf60cc6d74d4e;p=culture.git Update. --- diff --git a/main.py b/main.py index 00a9492..4ff50d7 100755 --- a/main.py +++ b/main.py @@ -423,23 +423,24 @@ def create_c_quizzes( c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] - nb_correct, seq_logproba = quiz_machine.compute_correctness( - c_quizzes, - models, - bidirectional_validation=args.bidirectional_validation, - deterministic_validation=args.deterministic_validation, - ) + if c_quizzes.size(0) > 0: + nb_correct, seq_logproba = quiz_machine.compute_correctness( + c_quizzes, + models, + bidirectional_validation=args.bidirectional_validation, + deterministic_validation=args.deterministic_validation, + ) - for n, l in zip(nb_correct, seq_logproba): - s = " ".join([str(x.item()) for x in l]) - logp_file.write(f"{n} {s}\n") + for n, l in zip(nb_correct, seq_logproba): + s = " ".join([str(x.item()) for x in l]) + logp_file.write(f"{n} {s}\n") - if args.dirty_debug: - nb_correct = torch.randint( - len(models) + 1, nb_correct.size(), device=c_quizzes.device - ) + if args.dirty_debug: + nb_correct = torch.randint( + len(models) + 1, nb_correct.size(), device=c_quizzes.device + ) - quizzes_and_nb_correct_records.append((c_quizzes, nb_correct)) + quizzes_and_nb_correct_records.append((c_quizzes, nb_correct)) nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) nv = " ".join([str(x.item()) for x in nv])