def generate_token_sequences(self, nb):
    """Generate `nb` quiz token sequences from the underlying problem.

    Asks `self.problem` for `nb` (prompt, answer) tensor pairs and
    records the prompt length the first time real data is seen, so
    later code can rely on `self.prompt_len`.

    Args:
        nb: number of (prompt, answer) pairs to generate.
    """
    prompts, answers = self.problem.generate_prompts_and_answers(nb)

    # Cache the prompt length on first use; prompts are assumed to
    # share a single sequence length (dim 1) — TODO confirm with the
    # problem implementation.
    # NOTE(review): the visible chunk ends here; confirm against the
    # full file whether this method also post-processes or returns
    # the generated sequences.
    if self.prompt_len is None:
        self.prompt_len = prompts.size(1)
######################################################################
def solution_token_logprobas(self, models, c_quizzes):
    """Per-token log-probabilities of the quizzes under each model.

    Args:
        models: iterable of GPT models; each must expose an integer
            `.id` used as its column index in the result.
        c_quizzes: long tensor of token sequences, shape
            (nb_quizzes, seq_len).

    Returns:
        Float32 tensor of shape (nb_quizzes, nb_models, seq_len);
        entry [q, m, t] is the log-probability model m assigns to
        token t of quiz q, zeroed outside the autoregressive mask
        (only positions selected by `make_ar_mask` are scored).
    """
    logproba = c_quizzes.new_zeros(
        c_quizzes.size(0),
        len(models),
        c_quizzes.size(1),
        device=self.device,
        dtype=torch.float32,
    )

    # Pure scoring: no gradients needed, and tracking them would
    # retain the autograd graph for every model pass.
    with torch.autograd.no_grad():
        for model in models:
            # Score in eval mode, restoring the previous
            # training/eval flag afterwards.
            t = model.training
            model.eval()

            input = c_quizzes.to(self.device)
            ar_mask = self.make_ar_mask(input)
            output = model(mygpt.BracketedSequence(input)).x

            # cross_entropy yields -log p(token) per position;
            # negate to get log-probs and mask to keep only the
            # answer tokens.
            logproba[:, model.id] = (
                -F.cross_entropy(
                    output.transpose(1, 2), input, reduction="none"
                )
                * ar_mask
            )

            model.train(t)

    return logproba