input,
ar_mask,
seq_logproba,
- temperature=1.0,
- deterministic_synthesis=False,
+ temperature,
+ deterministic_synthesis,
):
to_generate = (ar_mask.sum(0) > 0).nonzero()
n_backward = quizzes[:, 0] == self.token_backward
backward = quizzes[n_backward]
quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
- return not self.problem.trivial_prompts_and_answers(
- quizzes[:, 1 : 1 + self.prompt_len],
- quizzes[:, 2 + self.prompt_len :],
+ return torch.logical_not(
+ self.problem.trivial_prompts_and_answers(
+ quizzes[:, 1 : 1 + self.prompt_len],
+ quizzes[:, 2 + self.prompt_len :],
+ )
)
def reverse_time(self, quizzes):
quizzes,
mistakes=None,
):
- quizzes = quizzes.clone()
+ quizzes = quizzes.clone().to("cpu")
n_forward = quizzes[quizzes[:, 0] == self.token_forward]
n_backward = quizzes[:, 0] == self.token_backward
backward = quizzes[n_backward]
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
- predicted_prompts *= mistakes
- predicted_answers *= mistakes
+ predicted_prompts *= mistakes.to("cpu")
+ predicted_answers *= mistakes.to("cpu")
else:
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
else:
self.test_c_quizzes.append(new_c_quizzes)
+ def logproba_solution(self, models, c_quizzes):
+     """Score every quiz in `c_quizzes` under every model in `models`.
+
+     For each model, the quizzes are processed in chunks of
+     `self.batch_size`: the per-token cross-entropy between the model
+     output and the quiz tokens is multiplied by the mask returned by
+     `self.make_ar_mask` and summed along the last dimension.
+
+     Returns a float tensor with one row per quiz and one column per
+     model; the column written for a model is `model.id`.
+     """
+     # `c_quizzes` is passed as the target of F.cross_entropy, so it is
+     # presumably an integer tensor of token ids. A bare new_zeros()
+     # would inherit that dtype and truncate every stored score, hence
+     # the explicit floating-point dtype.
+     logproba = c_quizzes.new_zeros(
+         c_quizzes.size(0), len(models), dtype=torch.float32
+     )
+
+     for model in models:
+         # split() yields views, so writing into `scores` fills the
+         # matching rows of `logproba` in place.
+         for batch, scores in zip(
+             c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+         ):
+             ar_mask = self.make_ar_mask(batch)
+             output = model(mygpt.BracketedSequence(batch)).x
+             ce = (
+                 F.cross_entropy(output.transpose(1, 2), batch, reduction="none")
+                 * ar_mask
+             )
+             # NOTE(review): this stores the summed cross-entropy, i.e. the
+             # negative of a log-probability, despite the method name --
+             # confirm which sign the callers expect.
+             scores[:, model.id] = ce.sum(dim=-1)
+
+     return logproba
+
+ ###############################################################
+
def compute_correctness(
self,
c_quizzes,