X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=cdfba85e6a2abb8d6cd1b14bb3b0f76fe2afad30;hb=693af34e144cd20d2dde6a508a190d49c1a76c7f;hp=d4af77027d73c0017efe8cbecfaad8fcfc468cad;hpb=ceddc8cc3adbb045fdef1ccb0b3df2b8fed9eb4c;p=culture.git

diff --git a/quiz_machine.py b/quiz_machine.py
index d4af770..cdfba85 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -27,8 +27,8 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature=1.0,
-    deterministic_synthesis=False,
+    temperature,
+    deterministic_synthesis,
 ):
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
@@ -50,7 +50,8 @@
             t_next = dist.sample()
 
             all_n = torch.arange(t_next.size(0))
-            seq_logproba += logits[all_n, t_next].sum(dim=-1)
+
+            seq_logproba += logits[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
@@ -116,6 +117,19 @@ class QuizMachine:
         ).all()
         return i_forward, i_backward
 
+    def non_trivial(self, quizzes):
+        quizzes = quizzes.clone()
+        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
+        n_backward = quizzes[:, 0] == self.token_backward
+        backward = quizzes[n_backward]
+        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+        return torch.logical_not(
+            self.problem.trivial_prompts_and_answers(
+                quizzes[:, 1 : 1 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len :],
+            )
+        )
+
     def reverse_time(self, quizzes):
         i_forward, i_backward = self.indices_forward_and_backward(quizzes)
 
@@ -221,23 +235,21 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None
 
-        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
-        self.reverse_random_half_in_place(self.train_w_quizzes)
-        self.train_w_quizzes = self.train_w_quizzes.to(device)
+        # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
+        # self.reverse_random_half_in_place(self.train_w_quizzes)
 
-        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
-        self.reverse_random_half_in_place(self.test_w_quizzes)
-        self.test_w_quizzes = self.test_w_quizzes.to(device)
+        # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
+        # self.reverse_random_half_in_place(self.test_w_quizzes)
 
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-        if result_dir is not None:
-            self.save_quizzes(
-                result_dir,
-                "culture_w_quizzes",
-                self.train_w_quizzes[:72],
-            )
+        # if result_dir is not None:
+        #     self.save_quizzes(
+        #         result_dir,
+        #         "culture_w_quizzes",
+        #         self.train_w_quizzes[:72],
+        #     )
 
     def save_quizzes(
         self,
@@ -246,7 +258,7 @@
         quizzes,
         mistakes=None,
     ):
-        quizzes = quizzes.clone()
+        quizzes = quizzes.clone().to("cpu")
         n_forward = quizzes[quizzes[:, 0] == self.token_forward]
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
@@ -257,8 +269,8 @@
         predicted_answers = 1 - predicted_prompts
         if mistakes is not None:
             # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes
-            predicted_answers *= mistakes
+            predicted_prompts *= mistakes.to("cpu")
+            predicted_answers *= mistakes.to("cpu")
         else:
             # 0/2 ~ not-to-predict / to predict
             predicted_prompts *= 2
             predicted_answers *= 2
@@ -273,13 +285,13 @@
             predicted_answers,
         )
 
-    def batches(self, split="train", desc=None):
+    def batches(self, model, split="train", desc=None):
         assert split in {"train", "test"}
         if split == "train":
-            w_quizzes = self.train_w_quizzes
+            w_quizzes = model.train_w_quizzes
             c_quizzes = self.train_c_quizzes
         else:
-            w_quizzes = self.test_w_quizzes
+            w_quizzes = model.test_w_quizzes
             c_quizzes = self.test_c_quizzes
 
         if len(c_quizzes) > 0:
@@ -359,19 +371,19 @@
             backward_nb_total = correct[n_backward].size(0)
 
             self.logger(
-                f"{log_prefix}_forward_accuracy {n_epoch} {model.id=} {forward_nb_correct} / {forward_nb_total}"
+                f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
             )
 
             self.logger(
-                f"{log_prefix}_backward_accuracy {n_epoch} {model.id=} {backward_nb_correct} / {backward_nb_total}"
+                f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
             )
 
             return result, correct
 
-        compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
 
         test_result, test_correct = compute_accuracy(
-            self.test_w_quizzes[:nmax], log_prefix="test"
+            model.test_w_quizzes[:nmax], log_prefix="test"
         )
 
         main_test_accuracy = test_correct.sum() / test_correct.size(0)
@@ -388,8 +400,8 @@
 
         return main_test_accuracy
 
-    def renew_w_quizzes(self, nb, for_train=True):
-        input = self.train_w_quizzes if for_train else self.test_w_quizzes
+    def renew_w_quizzes(self, model, nb, for_train=True):
+        input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
@@ -402,6 +414,25 @@
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
+    def logproba_solution(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+
+        for model in models:
+            for input, l in zip(
+                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+            ):
+                ar_mask = self.make_ar_mask(input)
+                output = model(mygpt.BracketedSequence(input)).x
+                ce = (
+                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    * ar_mask
+                )
+                l[:, model.id] = -ce.sum(dim=-1)
+
+        return logproba
+
+    ###############################################################
+
     def compute_correctness(
         self,
         c_quizzes,
@@ -420,11 +451,11 @@
 
         nb_correct = 0
 
+        seq_logproba[...] = 0.0
+
        for model in models_for_validation:
             result = c_quizzes.clone()
 
-            seq_logproba[...] = 0.0
-
             ar_mask = self.make_ar_mask(result)
 
             masked_inplace_autoregression(
@@ -474,7 +505,10 @@
 
     def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
         c_quizzes = torch.empty(
-            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+            nb,
+            self.prompt_len + self.answer_len + 2,
+            device=self.device,
+            dtype=torch.int64,
         )
 
         seq_logproba = torch.zeros(nb, device=self.device)
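
The net effect of this commit is that each GPT model now owns its pool of world quizzes instead of QuizMachine holding one shared pool: batches() and renew_w_quizzes() take the model whose pool to use, generate_quizzes() sizes its buffer from prompt_len/answer_len rather than from the removed shared tensor, and the pool construction commented out of __init__ must now happen on the caller's side. Below is a minimal sketch of that calling side, assuming hypothetical driver names (models, quiz_machine, device, nb_train_samples, nb_test_samples); the actual driver lives in main.py and is not part of this diff.

    # Hypothetical driver sketch, not part of this patch: build the
    # per-model world-quiz pools with the same calls that were removed
    # from QuizMachine.__init__ above.
    for model in models:
        model.train_w_quizzes = quiz_machine.generate_token_sequences(nb_train_samples)
        quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
        model.train_w_quizzes = model.train_w_quizzes.to(device)

        model.test_w_quizzes = quiz_machine.generate_token_sequences(nb_test_samples)
        quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
        model.test_w_quizzes = model.test_w_quizzes.to(device)

    # Training iterates over a given model's own pool, then refreshes
    # part of that same pool with freshly generated world quizzes.
    for input in quiz_machine.batches(model, split="train"):
        ...  # one optimization step on input

    quiz_machine.renew_w_quizzes(model, nb=1000)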