parser.add_argument("--generation_temperature", type=float, default=2)
+parser.add_argument("--c_quiz_validation_mode", type=str, default="proba")
+
parser.add_argument("--dirty_debug", action="store_true", default=False)
######################################################################
else:
raise ValueError
-problem.save_some_examples(args.result_dir)
+if not args.resume:
+ problem.save_some_examples(args.result_dir)
quiz_machine = quiz_machine.QuizMachine(
problem=problem,
def keep_good_quizzes(models, quizzes):
quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
- token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
- l = token_logprobas.sum(dim=-1).sort(dim=-1).values
+ if args.c_quiz_validation_mode == "proba":
+ token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
+ l = token_logprobas.sum(dim=-1).sort(dim=-1).values
- to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
- l[:, 1] > math.log(args.proba_understands)
- )
+ to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
+ l[:, 1] > math.log(args.proba_understands)
+ )
+
+ elif args.c_quiz_validation_mode == "predict":
+ to_keep = quiz_machine.solution_nb_correct(models, quizzes) == (len(models) - 1)
+
+ else:
+ raise ValueError(f"{args.c_quiz_validation_mode=}")
if args.dirty_debug:
# warnings.warn("DEBUG", RuntimeWarning)
if nb_validated > 0 and nb_validated < nb_to_create:
d = (nb_to_create - nb_validated) * duration / nb_validated
+ e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
+ "%a %H:%M"
+ )
else:
- d = 0
-
- e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
- "%a %H:%M"
- )
+ e = "???"
log_string(
f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})"
###############################################################
- def compute_correctness(
+ def solution_nb_correct(
self,
- c_quizzes,
models_for_validation,
- bidirectional_validation=False,
- deterministic_validation=True,
+ c_quizzes,
+ deterministic_validation=False,
):
- if bidirectional_validation:
- backward_c_quizzes = self.forward_to_backward(c_quizzes)
-
seq_logproba = torch.zeros(
c_quizzes.size(0),
max([m.id for m in models_for_validation]) + 1,
seq_logproba[...] = 0.0
for model in models_for_validation:
+ c_quizzes = c_quizzes.to(self.device)
result = c_quizzes.clone()
ar_mask = self.make_ar_mask(result)
seq_logproba=seq_logproba[:, model.id],
temperature=1.0,
deterministic_synthesis=deterministic_validation,
- # progress_bar_desc="solving c_quizzes",
device=self.device,
)
correct = (c_quizzes == result).long().min(dim=-1).values
- if bidirectional_validation:
- backward_result = backward_c_quizzes.clone()
-
- ar_mask = self.make_ar_mask(backward_result)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=backward_result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba[:, model.id],
- temperature=1.0,
- deterministic_synthesis=deterministic_validation,
- # progress_bar_desc="solving backward c_quizzes",
- device=self.device,
- )
-
- backward_correct = (
- (backward_c_quizzes == backward_result).long().min(dim=-1).values
- )
-
- correct *= backward_correct
-
- # endif
-
nb_correct += correct
- return nb_correct, seq_logproba
+ return nb_correct.to("cpu")
###############################################################