######################################################################
+def identity_quizzes(quizzes):
+ quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
+ return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
+ quizzes[:, 2] == quizzes[:, 3]
+ ).min(dim=1).values
+
+
def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
record = []
nb_validated = 0
model=model, nb=args.eval_batch_size * 10, local_device=local_device
)
- # Select the ones that are solved properly by some models and
- # not understood by others
+ c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False]
- to_keep, nb_correct, nb_wrong = evaluate_quizzes(
- quizzes=c_quizzes,
- models=models,
- fraction_with_hints=1.0,
- local_device=local_device,
- )
+ if c_quizzes.size(0) > 0:
+ # Select the ones that are solved properly by some models and
+ # not understood by others
+
+ to_keep, nb_correct, nb_wrong = evaluate_quizzes(
+ quizzes=c_quizzes,
+ models=models,
+ fraction_with_hints=1.0,
+ local_device=local_device,
+ )
- nb_validated += to_keep.long().sum().item()
- record.append(c_quizzes[to_keep])
+ nb_validated += to_keep.long().sum().item()
+ record.append(c_quizzes[to_keep])
#####################