X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=quiz_machine.py;h=cdfba85e6a2abb8d6cd1b14bb3b0f76fe2afad30;hb=693af34e144cd20d2dde6a508a190d49c1a76c7f;hp=f0fb4082beb73913460bd4f72d4448e22ed15a1c;hpb=93cea45f62046a3481d6c05ab2cfe70f6dbc93b3;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index f0fb408..cdfba85 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -235,23 +235,21 @@ class QuizMachine: self.prompt_len = None self.answer_len = None - self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - self.reverse_random_half_in_place(self.train_w_quizzes) - self.train_w_quizzes = self.train_w_quizzes.to(device) + # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) + # self.reverse_random_half_in_place(self.train_w_quizzes) - self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) - self.reverse_random_half_in_place(self.test_w_quizzes) - self.test_w_quizzes = self.test_w_quizzes.to(device) + # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) + # self.reverse_random_half_in_place(self.test_w_quizzes) self.train_c_quizzes = [] self.test_c_quizzes = [] - if result_dir is not None: - self.save_quizzes( - result_dir, - "culture_w_quizzes", - self.train_w_quizzes[:72], - ) + # if result_dir is not None: + # self.save_quizzes( + # result_dir, + # "culture_w_quizzes", + # self.train_w_quizzes[:72], + # ) def save_quizzes( self, @@ -260,7 +258,7 @@ class QuizMachine: quizzes, mistakes=None, ): - quizzes = quizzes.clone() + quizzes = quizzes.clone().to("cpu") n_forward = quizzes[quizzes[:, 0] == self.token_forward] n_backward = quizzes[:, 0] == self.token_backward backward = quizzes[n_backward] @@ -271,8 +269,8 @@ class QuizMachine: predicted_answers = 1 - predicted_prompts if mistakes is not None: # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct - predicted_prompts *= mistakes - predicted_answers *= mistakes + predicted_prompts *= mistakes.to("cpu") + predicted_answers *= mistakes.to("cpu") else: # 0/2 ~ not-to-predict / to predict predicted_prompts *= 2 @@ -287,13 +285,13 @@ class QuizMachine: predicted_answers, ) - def batches(self, split="train", desc=None): + def batches(self, model, split="train", desc=None): assert split in {"train", "test"} if split == "train": - w_quizzes = self.train_w_quizzes + w_quizzes = model.train_w_quizzes c_quizzes = self.train_c_quizzes else: - w_quizzes = self.test_w_quizzes + w_quizzes = model.test_w_quizzes c_quizzes = self.test_c_quizzes if len(c_quizzes) > 0: @@ -382,10 +380,10 @@ class QuizMachine: return result, correct - compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train") + compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") test_result, test_correct = compute_accuracy( - self.test_w_quizzes[:nmax], log_prefix="test" + model.test_w_quizzes[:nmax], log_prefix="test" ) main_test_accuracy = test_correct.sum() / test_correct.size(0) @@ -402,8 +400,8 @@ class QuizMachine: return main_test_accuracy - def renew_w_quizzes(self, nb, for_train=True): - input = self.train_w_quizzes if for_train else self.test_w_quizzes + def renew_w_quizzes(self, model, nb, for_train=True): + input = model.train_w_quizzes if for_train else model.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() fresh_w_quizzes = self.generate_token_sequences(nb) @@ -416,6 +414,25 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes) + def logproba_solution(self, models, c_quizzes): + logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)) + + for model in models: + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + ce = ( + F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * ar_mask + ) + l[:, model.id] = -ce.sum(dim=-1) + + return logproba + + ############################################################### + def compute_correctness( self, c_quizzes, @@ -488,7 +505,10 @@ class QuizMachine: def generate_quizzes(self, nb, model_for_generation, temperature=1.0): c_quizzes = torch.empty( - nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 + nb, + self.prompt_len + self.answer_len + 2, + device=self.device, + dtype=torch.int64, ) seq_logproba = torch.zeros(nb, device=self.device)