From f80ee61414c46fd5a5ea55c849756ee38ea5a6b6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 24 Jul 2024 15:21:13 +0200 Subject: [PATCH] Update. --- grids.py | 59 ++++++---- quiz_machine.py | 295 +++++++++++++----------------------------------- 2 files changed, 115 insertions(+), 239 deletions(-) diff --git a/grids.py b/grids.py index eaba99a..99a9240 100755 --- a/grids.py +++ b/grids.py @@ -136,19 +136,6 @@ class Grids(problem.Problem): self.check_structure(quizzes, struct) return struct - def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)): - assert check_structure(quizzes, struct) - - ar_mask = quizzes.new_zeros(quizzes.size()) - - a = ar_mask.reshape(-1, 4, -1)[:, :, 1:] - a[:, 0, :] = mask[0] - a[:, 1, :] = mask[1] - a[:, 2, :] = mask[2] - a[:, 3, :] = mask[3] - - return ar_mask - def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): S = self.height * self.width @@ -166,6 +153,29 @@ class Grids(problem.Problem): return result + def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)): + assert check_structure(quizzes, struct) + + ar_mask = quizzes.new_zeros(quizzes.size()) + + a = ar_mask.reshape(-1, 4, -1)[:, :, 1:] + a[:, 0, :] = mask[0] + a[:, 1, :] = mask[1] + a[:, 2, :] = mask[2] + a[:, 3, :] = mask[3] + + return ar_mask + + def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")): + S = self.height * self.width + q = quizzes.reshape(-1, 4, S + 1) + return ( + (q[:, 0, 0] == self.l2tok[struct[0]]) + & (q[:, 1, 0] == self.l2tok[struct[1]]) + & (q[:, 2, 0] == self.l2tok[struct[2]]) + & (q[:, 3, 0] == self.l2tok[struct[3]]) + ) + def __init__( self, max_nb_cached_chunks=None, @@ -1368,12 +1378,6 @@ class Grids(problem.Problem): ###################################################################### - def trivial_prompts_and_answers(self, prompts, answers): - S = self.height * self.width - Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1] - f_Bs = answers[:, 1:] - return (Bs == f_Bs).long().min(dim=-1).values > 0 - def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False): if tasks is None: tasks = self.all_tasks @@ -1425,8 +1429,21 @@ if __name__ == "__main__": nb = 5 quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) print(grids.get_structure(quizzes)) - blah = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) - print(grids.get_structure(blah)) + quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) + print(grids.get_structure(quizzes)) + + i = torch.rand(quizzes.size(0)) < 0.5 + + quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A")) + + j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A")) + + print( + i.equal(j), + grids.get_structure(quizzes[j]), + grids.get_structure(quizzes[j == False]), + ) + exit(0) # nb = 1000 diff --git a/quiz_machine.py b/quiz_machine.py index d62ba3b..bc2a358 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -17,36 +17,6 @@ from mygpt import BracketedSequence import threading -###################################################################### -# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X -# | X != Y) - - -# output is NxCxT and target is NxT -def confusion(output, target, reduction="mean"): - N, C, T = output.shape - output = output.permute(0, 2, 1).reshape(-1, C) - target = target.flatten() - all_t = torch.arange(N * T, device=output.device) - output = output.log_softmax(dim=-1) - result = -output[all_t, target] - - output[all_t, target] = float("-inf") - output = output.log_softmax(dim=-1) - e = output.exp() - output[all_t, target] = 0 - result = result - (output * e).sum(-1) - - if reduction == "none": - return result.reshape(N, T) - elif reduction == "mean": - return result.reshape(N, T).mean() - elif reduction == "sum": - return result.reshape(N, T).sum() - else: - raise ValueError(f"unknown reduction '{reduction}'.") - - ###################################################################### # ar_mask is a tensor with 0s and 1s, of same shape as input, with @@ -139,47 +109,6 @@ def masked_inplace_autoregression( class QuizMachine: - def indices_p2a_and_a2p(self, quizzes): - i_p2a = quizzes[:, 0] == self.problem.token_forward - j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward - i_a2p = quizzes[:, 0] == self.problem.token_backward - j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward - assert ((i_p2a & j_p2a) | (i_a2p & j_a2p)).all() - return i_p2a, i_a2p - - def non_trivial(self, quizzes): - quizzes = quizzes.clone() - i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes) - quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p]) # a_fa_b_fb - return torch.logical_not( - self.problem.trivial_prompts_and_answers( - quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :] - ) - ) - - def p_a_flip_half_in_place(self, quizzes): - i = torch.rand(quizzes.size(0)) < 0.5 - if i.any(): - quizzes[i] = self.problem.p_a_flip(quizzes[i]) - - def generate_token_sequences(self, nb): - prompts, answers = self.problem.generate_prompts_and_answers(nb) - - if self.prompt_len is None: - self.prompt_len = prompts.size(1) - - if self.answer_len is None: - self.answer_len = answers.size(1) - - assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len - - result = [] - - for prompt, answer in zip(prompts, answers): - result.append(torch.cat([prompt, answer], dim=0)[None, :]) - - return torch.cat(result, dim=0) - def __init__( self, problem, @@ -191,8 +120,6 @@ class QuizMachine: ): super().__init__() - self.nb_token_values = problem.nb_token_values() - self.problem = problem self.back_accuracy = back_accuracy self.batch_size = batch_size @@ -205,47 +132,8 @@ class QuizMachine: self.train_c_quizzes = [] self.test_c_quizzes = [] - def save_quiz_illustrations( - self, - result_dir, - filename_prefix, - quizzes, - mistakes=None, - show_part_to_predict=True, - ): - quizzes = quizzes.clone().to("cpu") - i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes) - p2a = quizzes[i_p2a] - a2p = quizzes[i_a2p] - assert p2a.size(0) + a2p.size(0) == quizzes.size(0) - quizzes[i_a2p] = self.problem.p_a_flip(quizzes[i_a2p]) - - if show_part_to_predict: - predicted_prompts = i_a2p.long() - predicted_answers = 1 - predicted_prompts - if mistakes is not None: - # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct - predicted_prompts *= mistakes.to("cpu") - predicted_answers *= mistakes.to("cpu") - else: - # 0/2 ~ not-to-predict / to predict - predicted_prompts *= 2 - predicted_answers *= 2 - else: - predicted_prompts = None - predicted_answers = None - - self.problem.save_quiz_illustrations( - result_dir, - filename_prefix, - quizzes[:, : self.prompt_len], - quizzes[:, self.prompt_len :], - predicted_prompts, - predicted_answers, - ) - def vocabulary_size(self): - return self.nb_token_values + return self.problem.nb_token_values ###################################################################### @@ -289,11 +177,12 @@ class QuizMachine: def produce_results( self, n_epoch, model, input, result_dir, deterministic_synthesis ): - def compute_accuracy(input, log_prefix=None): - input = input.to(self.device) - ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123") - result = input.clone() * (1 - ar_mask) - seq_logproba = torch.empty(input.size(0), device=self.device) + def predict(input, struct, mask): + ar_mask = self.problem.make_ar_mask( + quizzes=quizzes, struct=struct, mask=mask + ) + result = quizzes * (1 - ar_mask) + seq_logproba = torch.empty(fwd_quizzes, device=self.device) masked_inplace_autoregression( model=model, @@ -306,37 +195,29 @@ class QuizMachine: device=self.device, ) - correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device) - - i_p2a, i_a2p = self.indices_p2a_and_a2p(input) - - correct[i_p2a] = (input[i_p2a] == result[i_p2a]).long().min(dim=1).values - - if self.back_accuracy and i_a2p.any(): - # accuracy of B->A*->B*=B instead of B->A*=A - back_input = self.problem.p_a_flip(result[i_a2p]) - back_input[:, 1 + self.prompt_len :] = input[i_a2p, 1 : self.answer_len] - _, correct[i_a2p] = compute_accuracy(back_input) - - if log_prefix is not None: - p2a_nb_correct = correct[i_p2a].sum() - p2a_nb_total = correct[i_p2a].size(0) - a2p_nb_correct = correct[i_a2p].sum() - a2p_nb_total = correct[i_a2p].size(0) - - self.logger( - f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}" - ) + nb_correct = (result == quizzes).min(dim=1).long() return result, correct - test_result, test_correct = compute_accuracy(input, log_prefix="test") + input = input.to(self.device) + i = self.problem.indices_select(quizzes=input, struct=struct) - n_test_p2a = input[:, 0] == self.problem.token_forward + test_result_fwd, test_correct_fwd = predict( + input[i], ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) - p2a_test_correct = test_correct[n_test_p2a] + input_bck = self.problem.reconfigure( + predict(input[i == False], ("f_B", "f_A", "B", "A"), (0, 1, 1, 1))[0], + struct=("A", "f_A", "B", "f_B"), + ) + + l = input_bck.size(1) + input_bck[:, 3 * l :] = input[i == False][:, :l] + test_result_bck, test_correct_bck = predict( + input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) - main_test_accuracy = p2a_test_correct.sum() / p2a_test_correct.size(0) + main_test_accuracy = test_correct.sum() / test_correct.size(0) ############################## @@ -351,19 +232,27 @@ class QuizMachine: ###################################################################### - def create_w_quizzes( - self, model, nb_train_samples, nb_test_samples, p2a_only=False - ): - model.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - model.test_w_quizzes = self.generate_token_sequences(nb_test_samples) + def flip_half_in_place(self, quizzes): + r = torch.randint(quizzes.size(0), device=quizzes.device) < 0.5 + i = self.problem.indices_select(quizzes=input, struct=("A", "f_A", "B", "f_B")) + quizzes[i & r] = self.problem.reconfigure( + quizzes[i & r], struct=("f_B", "f_A", "B", "A") + ) + j = self.problem.indices_select(quizzes=input, struct=("f_B", "f_A", "B", "A")) + quizzes[j & r] = self.problem.reconfigure( + quizzes[j & r], struct=("A", "f_A", "B", "f_B") + ) - if not p2a_only: - self.p_a_flip_half_in_place(model.train_w_quizzes) - self.p_a_flip_half_in_place(model.test_w_quizzes) + def create_w_quizzes(self, model, nb_train_samples, nb_test_samples): + model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples) + model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples) + + self.flip_half_in_place(model.train_w_quizzes) + self.flip_half_in_place(model.test_w_quizzes) ###################################################################### - def renew_train_w_quizzes(self, model, p2a_only=False): + def renew_train_w_quizzes(self, model): if hasattr(model, "hard_w_quizzes"): self.logger( f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" @@ -379,19 +268,18 @@ class QuizMachine: model.train_w_quizzes[...] = torch.cat( [ model.hard_w_quizzes, - self.generate_token_sequences( + self.problem.generate_w_quizzes( model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0) ), ], dim=0, ) else: - model.train_w_quizzes[...] = self.generate_token_sequences( + model.train_w_quizzes[...] = self.problem.generate_w_quizzes( model.train_w_quizzes.size(0) ) - if not p2a_only: - self.p_a_flip_half_in_place(model.train_w_quizzes) + self.flip_half_in_place(model.train_w_quizzes) ###################################################################### @@ -481,9 +369,7 @@ class QuizMachine: # ------------------------------- # f(A), A, f(B) | B - c_quizzes = self.problem.p_a_flip(c_quizzes, pairwise_flip=True).to( - self.device - ) + c_quizzes = self.problem.flip(c_quizzes, pairwise_flip=True).to(self.device) result = c_quizzes.clone() ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3") @@ -512,7 +398,6 @@ class QuizMachine: self, nb, model_for_generation, - p2a_only=False, temperature_hot=1.0, temperature_cold=1.0, ): @@ -533,68 +418,42 @@ class QuizMachine: # ) # lt_clean = None - if p2a_only: - c_quizzes[...] = self.problem.token_forward + c_quizzes[...] = self.problem.token_backward - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - deterministic_synthesis=False, - device=self.device, - ) - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - deterministic_synthesis=False, - device=self.device, - ) - - else: - c_quizzes[...] = self.problem.token_backward - - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), - seq_logproba=seq_logproba, - logit_transformer=lt_noisy, - deterministic_synthesis=False, - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + deterministic_synthesis=False, + device=self.device, + ) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - deterministic_synthesis=False, - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, + deterministic_synthesis=False, + device=self.device, + ) - c_quizzes = self.problem.p_a_flip(c_quizzes) + c_quizzes = self.problem.p_a_flip(c_quizzes) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), - seq_logproba=seq_logproba, - logit_transformer=lt_clean, - deterministic_synthesis=False, - device=self.device, - ) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, + deterministic_synthesis=False, + device=self.device, + ) return c_quizzes.to("cpu") -- 2.20.1