acc_train_loss += loss.item() * input.size(0)
loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
- n_p2a = input[:, 0] == quiz_machine.token_p2a
+ n_p2a = input[:, 0] == quiz_machine.problem.token_forward
to_store = from_w & n_p2a.to("cpu")
if to_store.any():
hard_w_quizzes.append(
v_train = validated_quizzes[:nb_for_train]
quiz_machine.store_c_quizzes(v_train, for_train=True)
- quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_train), for_train=True)
+ quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_train), for_train=True)
v_test = validated_quizzes[nb_for_train:nb_to_create]
quiz_machine.store_c_quizzes(v_test, for_train=False)
- quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_test), for_train=False)
+ quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_test), for_train=False)
######################################################################
# save images
self.problem.save_quiz_illustrations(
result_dir,
filename_prefix,
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
+ quizzes[:, : self.prompt_len],
+ quizzes[:, self.prompt_len :],
predicted_prompts,
predicted_answers,
)
if self.back_accuracy and n_a2p.any():
# round-trip accuracy: score B->A*->B* == B instead of the direct B->A* == A
back_input = self.p_a_flip(result[n_a2p])
- back_input[:, 2 + self.prompt_len :] = input[
- n_a2p, 1 : 1 + self.answer_len
- ]
+ back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len]
_, correct[n_a2p] = compute_accuracy(back_input)
if log_prefix is not None: