def standard_validity(logproba):
l = logproba.sort(dim=-1).values
- return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.95))
+ return (l[:, 0] < math.log(0.5)) & (l[:, 1] > math.log(0.99))
+ # warnings.warn("TEST!!!", RuntimeWarning)
+ # print(l.exp())
+ # return (l[:, 0] < math.log(0.99))
def valid_c_quizzes(recorded, criteria):
# ------------------------------------------------------------
- standard_validity = lambda nb_correct: torch.logical_and(
- nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
+ standard_validity = lambda nb_correct: (nb_correct >= args.min_to_validate) & (
+ nb_correct <= args.max_to_validate
)
file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
F.cross_entropy(output.transpose(1, 2), input, reduction="none")
* ar_mask
)
- l[:, model.id] = ce.sum(dim=-1)
+ l[:, model.id] = -ce.sum(dim=-1)
return logproba