return torch.cat(record)
-def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device):
+def predict_full(model, input, fraction_with_hints, local_device=main_device):
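# replicate each quiz 4 times along the batch dimension
# (presumably one copy per grid of the quiz to be predicted)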
input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
nb = input.size(0)
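# masks: zero tensor of the same shape as the replicated input, presumably
# filled in later (not shown in this hunk) to mark the positions to predict
# and the fraction_with_hints positions revealed as hints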
masks = input.new_zeros(input.size())
######################################################################
-def evaluate_quizzes(c_quizzes, models, local_device):
+def evaluate_quizzes(c_quizzes, models, fraction_with_hints, local_device):
nb_correct, nb_wrong = 0, 0
for model in models:
model = copy.deepcopy(model).to(local_device).eval()
- result = predict_full(model, c_quizzes, local_device=local_device)
+ result = predict_full(
+ model=model,
+ input=c_quizzes,
+ fraction_with_hints=fraction_with_hints,
+ local_device=local_device,
+ )
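# per-quiz tallies across models: nb_mistakes counts the token positions this
# model got wrong, nb_correct the models that solved the quiz exactly,
# nb_wrong the models with at least args.nb_mistakes_to_be_wrong mistakes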
nb_mistakes = (result != c_quizzes).long().sum(dim=1)
nb_correct += (nb_mistakes == 0).long()
nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
# not understood by others
to_keep, nb_correct, nb_wrong = evaluate_quizzes(
- c_quizzes, models, local_device
+ c_quizzes=c_quizzes,
+ models=models,
+ fraction_with_hints=1.0,
+ local_device=local_device,
)
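# to_keep is the per-quiz boolean mask returned by evaluate_quizzes;
# count how many c_quizzes were validated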
nb_validated += to_keep.long().sum().item()