+ sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
+ sum_nb_c_quizzes += new_c_quizzes.size(0)
+
+ if args.dirty_debug:
+ nb_correct = torch.randint(
+ len(models) + 1, nb_correct.size(), device=new_c_quizzes.device
+ )
+
+ for n in range(nb_correct.max() + 1):
+ recorded[n].append(new_c_quizzes[nb_correct == n].clone())
+
+ nb_validated = sum([x.size(0) for x in recorded[nb_correct_to_validate]])
+ nb_generated = sum(
+ [sum([x.size(0) for x in recorded[n]]) for n in recorded.keys()]
+ )