parser.add_argument("--proba_not_understands", type=float, default=0.5)
-parser.add_argument("--generation_temperature", type=float, default=2.5)
+parser.add_argument("--generation_temperature", type=float, default=1.5)
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
)
elif args.c_quiz_validation_mode == "predict":
- to_keep = quiz_machine.solution_nb_correct(models, quizzes) == (len(models) - 1)
+ nc = quiz_machine.solution_nb_correct(models, quizzes)
+ count_nc = tuple(
+ n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
+ )
+ log_string(f"nb_correct {count_nc}")
+ to_keep = nc == (len(models) - 1)
else:
raise ValueError(f"{args.c_quiz_validation_mode=}")
input=c_quizzes,
ar_mask=self.make_ar_mask(c_quizzes),
seq_logproba=seq_logproba,
- temperature=1,
+ temperature=0.75,
deterministic_synthesis=False,
device=self.device,
)
input=c_quizzes,
ar_mask=self.make_ar_mask(c_quizzes),
seq_logproba=seq_logproba,
- temperature=1,
+ temperature=0.75,
deterministic_synthesis=False,
device=self.device,
)