parser.add_argument("--reverse_cleanup", action="store_true", default=False)
+parser.add_argument("--validation_forward_only", action="store_true", default=False)
+
parser.add_argument("--problem", type=str, default="sky")
parser.add_argument("--nb_gpts", type=int, default=5)
sum_nb_c_quizzes += c_quizzes.size(0)
nb_correct = quizz_machine.compute_correctness(
- c_quizzes, models, both_direction=True
+ c_quizzes, models, both_directions=not args.validation_forward_only
)
if args.dirty_debug:
f"test_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
)
- cta = " ".join([f"{float(m.main_test_accuracy):.02f}" for m in models])
+ cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
log_string(f"current_test_accuracies {cta}")
# replace a fraction of the w_quizzes with fresh ones
return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
def compute_correctness(
- self, c_quizzes, models_for_validation, both_direction=True
+ self, c_quizzes, models_for_validation, both_directions=True
):
reversed_c_quizzes = self.reverse_time(c_quizzes)
correct = (c_quizzes == result).long().min(dim=-1).values
- if both_direction:
+ if both_directions:
reversed_result = reversed_c_quizzes.clone()
masked_inplace_autoregression(