run_tests(model, quiz_machine)
- threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
- threshold = threshold[threshold.size(0) // 2]
+ # threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
+ # threshold = threshold[threshold.size(0) // 2]
- model.hard_w_quizzes = torch.cat(
- [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
- )
+ # model.hard_w_quizzes = torch.cat(
+ # [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
+ # )
model.to(main_device)
nb = 0
# We consider all the configurations that we train for
- for struct, mask, noise_mask in self.understood_structures:
+ for struct, mask, _ in self.understood_structures:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
result[i], correct[i] = self.predict(