+def compute_valid_quizzes(token_logprobas):
+ l = token_logprobas.sum(dim=-1).sort(dim=-1).values
+ return (l[:, 0] < math.log(args.proba_not_understands)) & (
+ l[:, 1] > math.log(args.proba_understands)
+ )
+
+
+def extract_valid_quizzes_and_logprobas(recorded):
+ validated_quizzes, validated_logprobas = [], []
+ for quizzes, token_logprobas in recorded:
+ validated_indices = compute_valid_quizzes(token_logprobas)
+ validated_quizzes.append(quizzes[validated_indices])
+ validated_logprobas.append(token_logprobas[validated_indices])
+
+ if len(validated_quizzes) > 0:
+ return torch.cat(validated_quizzes, dim=0), torch.cat(
+ validated_logprobas, dim=0
+ )
+ else:
+ return None, None