forward_to_backward = torch.cat(
[
quizzes[:, 0:1],
- quizzes[:, 2 + self.prompt_len :],
- quizzes[:, 1 + self.prompt_len : 2 + self.prompt_len],
+ quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
+ quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
quizzes[:, 1 : 1 + self.prompt_len],
],
dim=1,
)
+
forward_to_backward[:, 0] = self.token_backward
forward_to_backward[:, 1 + self.answer_len] = self.token_backward
if result_dir is not None:
self.save_quizzes(
- result_dir, "culture_w_quizzes", self.train_w_quizzes[:72]
+ result_dir,
+ "culture_w_quizzes",
+ self.train_w_quizzes[:72],
+ n_backward=self.train_w_quizzes[:72, 0] == self.token_backward,
)
- # toto = self.reverse_time(self.train_w_quizzes[:72])
- # self.save_quizzes(result_dir, "toto", toto)
- # exit(0)
-
- def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
+ def save_quizzes(
+ self,
+ result_dir,
+ filename_prefix,
+ quizzes,
+ n_backward=None,
+ mistakes=None,
+ ):
+ quizzes = quizzes.clone()
forward = quizzes[quizzes[:, 0] == self.token_forward]
ib = quizzes[:, 0] == self.token_backward
backward = quizzes[ib]
assert forward.size(0) + backward.size(0) == quizzes.size(0)
quizzes[ib] = self.reverse_time(quizzes[ib])
- if prediction:
- predicted_prompts = ib
- predicted_answers = torch.logical_not(ib)
- else:
+ if n_backward is None:
predicted_prompts = None
predicted_answers = None
+ else:
+ predicted_prompts = n_backward.long()
+ predicted_answers = 1 - predicted_prompts
+ if mistakes is not None:
+ # marker values: 0 = not to predict, -1 = predicted wrong, +1 = predicted correct
+ predicted_prompts *= mistakes
+ predicted_answers *= mistakes
+ else:
+ # marker values: 0 = not to predict, 2 = to predict
+ predicted_prompts *= 2
+ predicted_answers *= 2
self.problem.save_quizzes(
result_dir,
def produce_results(
self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
):
- def compute_accuracy(input):
+ def compute_accuracy(input, log_prefix=None):
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
device=self.device,
)
- if self.back_accuracy:
- n_forward = input[:, 0] == self.token_forward
- nb_total = input[n_forward].size(0)
- nb_correct = (
- (input[n_forward] == result[n_forward])
- .long()
- .min(dim=1)
- .values.sum()
- )
+ correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
+
+ n_forward = input[:, 0] == self.token_forward
+ n_backward = input[:, 0] == self.token_backward
- n_backward = input[:, 0] == self.token_backward
+ correct[n_forward] = (
+ (input[n_forward] == result[n_forward]).long().min(dim=1).values
+ )
+
+ if self.back_accuracy and n_backward.any():
+ # back accuracy: score the round trip B -> A* -> B* == B instead of the direct prediction B -> A* == A
back_input = self.reverse_time(result[n_backward])
- if back_input.size(0) > 0:
- back_input[:, 2 + self.prompt_len :] = input[
- n_backward, 2 + self.prompt_len :
- ]
- back_nb_total, back_nb_correct = compute_accuracy(back_input)
- nb_total += back_nb_total
- nb_correct += back_nb_correct
- else:
- nb_total = input.size(0)
- nb_correct = (input == result).long().min(dim=1).values.sum()
+ back_input[:, 2 + self.prompt_len :] = input[
+ n_backward, 1 : 1 + self.answer_len
+ ]
+ result[n_backward], correct[n_backward] = compute_accuracy(back_input)
- return nb_total, nb_correct
+ if log_prefix is not None:
+ forward_nb_correct = correct[n_forward].sum()
+ forward_nb_total = correct[n_forward].size(0)
+ backward_nb_correct = correct[n_backward].sum()
+ backward_nb_total = correct[n_backward].size(0)
- train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax])
+ self.logger(
+ f"forward_accuracy {log_prefix} {n_epoch} {model.id=} {forward_nb_correct} / {forward_nb_total}"
+ )
- self.logger(
- f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
- )
+ self.logger(
+ f"backward_accuracy {log_prefix} {n_epoch} {model.id=} {backward_nb_correct} / {backward_nb_total}"
+ )
+
+ return result, correct
- test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes[:nmax])
+ compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
- self.logger(
- f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ test_result, test_correct = compute_accuracy(
+ self.test_w_quizzes[:nmax], log_prefix="test"
)
- main_test_accuracy = test_nb_correct / test_nb_total
+ main_test_accuracy = test_correct.sum() / test_correct.size(0)
self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
##############################
- input = self.test_w_quizzes[:96]
- ar_mask = self.make_ar_mask(input)
- result = input.clone() * (1 - ar_mask)
- seq_logproba = torch.empty(input.size(0), device=self.device)
-
- masked_inplace_autoregression(
- model=model,
- batch_size=self.batch_size,
- input=result,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=1.0,
- deterministic_synthesis=deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
-
self.save_quizzes(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- quizzes=result[:72],
- prediction=True,
+ quizzes=test_result[:72],
+ n_backward=self.test_w_quizzes[:72, 0] == self.token_backward,
+ mistakes=test_correct[:72] * 2 - 1,
)
return main_test_accuracy