X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=cdfba85e6a2abb8d6cd1b14bb3b0f76fe2afad30;hb=693af34e144cd20d2dde6a508a190d49c1a76c7f;hp=c1477c9bb497498ddbe5aa4b2cc898ba1a915796;hpb=f7a8cd141a39039048d3e9311220a33079f2cfc7;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index c1477c9..cdfba85 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -235,23 +235,21 @@ class QuizMachine: self.prompt_len = None self.answer_len = None - self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - self.reverse_random_half_in_place(self.train_w_quizzes) - self.train_w_quizzes = self.train_w_quizzes.to(device) + # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) + # self.reverse_random_half_in_place(self.train_w_quizzes) - self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) - self.reverse_random_half_in_place(self.test_w_quizzes) - self.test_w_quizzes = self.test_w_quizzes.to(device) + # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) + # self.reverse_random_half_in_place(self.test_w_quizzes) self.train_c_quizzes = [] self.test_c_quizzes = [] - if result_dir is not None: - self.save_quizzes( - result_dir, - "culture_w_quizzes", - self.train_w_quizzes[:72], - ) + # if result_dir is not None: + # self.save_quizzes( + # result_dir, + # "culture_w_quizzes", + # self.train_w_quizzes[:72], + # ) def save_quizzes( self, @@ -287,13 +285,13 @@ class QuizMachine: predicted_answers, ) - def batches(self, split="train", desc=None): + def batches(self, model, split="train", desc=None): assert split in {"train", "test"} if split == "train": - w_quizzes = self.train_w_quizzes + w_quizzes = model.train_w_quizzes c_quizzes = self.train_c_quizzes else: - w_quizzes = self.test_w_quizzes + w_quizzes = model.test_w_quizzes c_quizzes = self.test_c_quizzes if len(c_quizzes) > 0: @@ -382,10 +380,10 @@ class QuizMachine: return result, correct - compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train") + compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") test_result, test_correct = compute_accuracy( - self.test_w_quizzes[:nmax], log_prefix="test" + model.test_w_quizzes[:nmax], log_prefix="test" ) main_test_accuracy = test_correct.sum() / test_correct.size(0) @@ -402,8 +400,8 @@ class QuizMachine: return main_test_accuracy - def renew_w_quizzes(self, nb, for_train=True): - input = self.train_w_quizzes if for_train else self.test_w_quizzes + def renew_w_quizzes(self, model, nb, for_train=True): + input = model.train_w_quizzes if for_train else model.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() fresh_w_quizzes = self.generate_token_sequences(nb) @@ -429,7 +427,7 @@ class QuizMachine: F.cross_entropy(output.transpose(1, 2), input, reduction="none") * ar_mask ) - l[:, model.id] = ce.sum(dim=-1) + l[:, model.id] = -ce.sum(dim=-1) return logproba @@ -507,7 +505,10 @@ class QuizMachine: def generate_quizzes(self, nb, model_for_generation, temperature=1.0): c_quizzes = torch.empty( - nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 + nb, + self.prompt_len + self.answer_len + 2, + device=self.device, + dtype=torch.int64, ) seq_logproba = torch.zeros(nb, device=self.device)