c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
- nb_correct, seq_logproba = quiz_machine.compute_correctness(
- c_quizzes,
- models,
- bidirectional_validation=args.bidirectional_validation,
- deterministic_validation=args.deterministic_validation,
- )
+ if c_quizzes.size(0) > 0:
+ nb_correct, seq_logproba = quiz_machine.compute_correctness(
+ c_quizzes,
+ models,
+ bidirectional_validation=args.bidirectional_validation,
+ deterministic_validation=args.deterministic_validation,
+ )
- for n, l in zip(nb_correct, seq_logproba):
- s = " ".join([str(x.item()) for x in l])
- logp_file.write(f"{n} {s}\n")
+ for n, l in zip(nb_correct, seq_logproba):
+ s = " ".join([str(x.item()) for x in l])
+ logp_file.write(f"{n} {s}\n")
- if args.dirty_debug:
- nb_correct = torch.randint(
- len(models) + 1, nb_correct.size(), device=c_quizzes.device
- )
+ if args.dirty_debug:
+ nb_correct = torch.randint(
+ len(models) + 1, nb_correct.size(), device=c_quizzes.device
+ )
- quizzes_and_nb_correct_records.append((c_quizzes, nb_correct))
+ quizzes_and_nb_correct_records.append((c_quizzes, nb_correct))
nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
nv = " ".join([str(x.item()) for x in nv])