result_dir,
"culture_w_quizzes",
self.train_w_quizzes[:72],
- show_to_be_predicted=True,
+ n_backward=self.train_w_quizzes[:72, 0] == self.token_backward,
)
def save_quizzes(
result_dir,
filename_prefix,
quizzes,
- show_to_be_predicted=False,
+ n_backward=None,
mistakes=None,
):
quizzes = quizzes.clone()
assert forward.size(0) + backward.size(0) == quizzes.size(0)
quizzes[ib] = self.reverse_time(quizzes[ib])
- if show_to_be_predicted:
- predicted_prompts = ib.long()
+ if n_backward is None:
+ predicted_prompts = None
+ predicted_answers = None
+ else:
+ predicted_prompts = n_backward.long()
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
predicted_answers *= 2
- else:
- predicted_prompts = None
- predicted_answers = None
self.problem.save_quizzes(
result_dir,
def produce_results(
    self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
):
    """Evaluate `model` on the stored quizzes, log accuracies, and save images.

    Runs autoregressive prediction on up to `nmax` train and test quizzes,
    logs forward/backward accuracy for each split, saves a sample of 72
    annotated test predictions, and returns the overall test accuracy
    (correct / total, as a 0-dim tensor).
    """

    def compute_accuracy(input, log_prefix=None):
        # Blank out the to-be-predicted positions and let the model
        # regenerate them in place.
        ar_mask = self.make_ar_mask(input)
        result = input.clone() * (1 - ar_mask)
        seq_logproba = torch.empty(input.size(0), device=self.device)

        masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=result,
            ar_mask=ar_mask,
            seq_logproba=seq_logproba,
            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        # Per-quiz exact-match correctness: 1 iff every token matches.
        # Computed for ALL rows up front so no entry is ever left
        # uninitialized (the previous torch.empty + partial assignment
        # left correct[n_backward] as garbage whenever back_accuracy was
        # off while backward quizzes were present).
        correct = (input == result).long().min(dim=1).values

        n_forward = input[:, 0] == self.token_forward
        n_backward = input[:, 0] == self.token_backward

        if self.back_accuracy and n_backward.any():
            # For backward quizzes, score B->A*->B*==B instead of
            # B->A*==A: feed the model its own reconstructed prompt A*
            # (time-reversed, with the true answer slot restored) and
            # check the answer it produces, overwriting the plain
            # comparison computed above.
            back_input = self.reverse_time(result[n_backward])
            back_input[:, 2 + self.prompt_len :] = input[
                n_backward, 1 : 1 + self.answer_len
            ]
            result[n_backward], correct[n_backward] = compute_accuracy(back_input)

        if log_prefix is not None:
            forward_nb_correct = correct[n_forward].sum()
            forward_nb_total = correct[n_forward].size(0)
            backward_nb_correct = correct[n_backward].sum()
            backward_nb_total = correct[n_backward].size(0)

            self.logger(
                f"forward_accuracy {log_prefix} {n_epoch} {model.id=} {forward_nb_correct} / {forward_nb_total}"
            )

            self.logger(
                f"backward_accuracy {log_prefix} {n_epoch} {model.id=} {backward_nb_correct} / {backward_nb_total}"
            )

        return result, correct

    # Train accuracy is computed only for its logging side effect.
    compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")

    test_result, test_correct = compute_accuracy(
        self.test_w_quizzes[:nmax], log_prefix="test"
    )

    main_test_accuracy = test_correct.sum() / test_correct.size(0)
    self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

    ##############################

    self.save_quizzes(
        result_dir,
        f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
        quizzes=test_result[:72],
        n_backward=self.test_w_quizzes[:72, 0] == self.token_backward,
        # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
        mistakes=test_correct[:72] * 2 - 1,
    )

    return main_test_accuracy