Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 19:56:15 +0000 (22:56 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 19:56:15 +0000 (22:56 +0300)
main.py

diff --git a/main.py b/main.py
index 00a9492..4ff50d7 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -423,23 +423,24 @@ def create_c_quizzes(
 
             c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
-            nb_correct, seq_logproba = quiz_machine.compute_correctness(
-                c_quizzes,
-                models,
-                bidirectional_validation=args.bidirectional_validation,
-                deterministic_validation=args.deterministic_validation,
-            )
+            if c_quizzes.size(0) > 0:
+                nb_correct, seq_logproba = quiz_machine.compute_correctness(
+                    c_quizzes,
+                    models,
+                    bidirectional_validation=args.bidirectional_validation,
+                    deterministic_validation=args.deterministic_validation,
+                )
 
-            for n, l in zip(nb_correct, seq_logproba):
-                s = " ".join([str(x.item()) for x in l])
-                logp_file.write(f"{n} {s}\n")
+                for n, l in zip(nb_correct, seq_logproba):
+                    s = " ".join([str(x.item()) for x in l])
+                    logp_file.write(f"{n} {s}\n")
 
-            if args.dirty_debug:
-                nb_correct = torch.randint(
-                    len(models) + 1, nb_correct.size(), device=c_quizzes.device
-                )
+                if args.dirty_debug:
+                    nb_correct = torch.randint(
+                        len(models) + 1, nb_correct.size(), device=c_quizzes.device
+                    )
 
-            quizzes_and_nb_correct_records.append((c_quizzes, nb_correct))
+                quizzes_and_nb_correct_records.append((c_quizzes, nb_correct))
 
             nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
             nv = " ".join([str(x.item()) for x in nv])