self.prompt_len = None
self.answer_len = None
- self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
- self.reverse_random_half_in_place(self.train_w_quizzes)
- self.train_w_quizzes = self.train_w_quizzes.to(device)
+ # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
+ # self.reverse_random_half_in_place(self.train_w_quizzes)
- self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
- self.reverse_random_half_in_place(self.test_w_quizzes)
- self.test_w_quizzes = self.test_w_quizzes.to(device)
+ # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
+ # self.reverse_random_half_in_place(self.test_w_quizzes)
self.train_c_quizzes = []
self.test_c_quizzes = []
- if result_dir is not None:
- self.save_quizzes(
- result_dir,
- "culture_w_quizzes",
- self.train_w_quizzes[:72],
- )
+ # if result_dir is not None:
+ # self.save_quizzes(
+ # result_dir,
+ # "culture_w_quizzes",
+ # self.train_w_quizzes[:72],
+ # )
def save_quizzes(
self,
predicted_answers,
)
- def batches(self, split="train", desc=None):
+ def batches(self, model, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
- w_quizzes = self.train_w_quizzes
+ w_quizzes = model.train_w_quizzes
c_quizzes = self.train_c_quizzes
else:
- w_quizzes = self.test_w_quizzes
+ w_quizzes = model.test_w_quizzes
c_quizzes = self.test_c_quizzes
if len(c_quizzes) > 0:
return result, correct
- compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+ compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
test_result, test_correct = compute_accuracy(
- self.test_w_quizzes[:nmax], log_prefix="test"
+ model.test_w_quizzes[:nmax], log_prefix="test"
)
main_test_accuracy = test_correct.sum() / test_correct.size(0)
return main_test_accuracy
- def renew_w_quizzes(self, nb, for_train=True):
- input = self.train_w_quizzes if for_train else self.test_w_quizzes
+ def renew_w_quizzes(self, model, nb, for_train=True):
+ input = model.train_w_quizzes if for_train else model.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
fresh_w_quizzes = self.generate_token_sequences(nb)
F.cross_entropy(output.transpose(1, 2), input, reduction="none")
* ar_mask
)
- l[:, model.id] = ce.sum(dim=-1)
+ l[:, model.id] = -ce.sum(dim=-1)
return logproba
def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
c_quizzes = torch.empty(
- nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+ nb,
+ self.prompt_len + self.answer_len + 2,
+ device=self.device,
+ dtype=torch.int64,
)
seq_logproba = torch.zeros(nb, device=self.device)