return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25)
+def c_quiz_criterion_one_good_no_very_bad(probas):
+ return (
+ (probas.max(dim=1).values >= 0.75)
+ & (probas.min(dim=1).values <= 0.75)
+ & (probas.min(dim=1).values >= 0.25)
+ )
+
+
def c_quiz_criterion_diff(probas):
return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5
return (v[:, -2] - v[:, 0]) >= 0.5
+def c_quiz_criterion_only_one(probas):
+ v = probas.sort(dim=1).values
+ return (v[:, -1] >= 0.75) & (v[:, -2] <= 0.25)
+
+
def c_quiz_criterion_two_good(probas):
return ((probas >= 0.5).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.2)
def generate_ae_c_quizzes(models, local_device=main_device):
criteria = [
+ # c_quiz_criterion_only_one,
c_quiz_criterion_one_good_one_bad,
+ # c_quiz_criterion_one_good_no_very_bad,
# c_quiz_criterion_diff,
# c_quiz_criterion_diff2,
# c_quiz_criterion_two_good,