From e12aa3370fb4d89803746e72bfb931b182e39592 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 24 Jul 2024 18:22:50 +0200 Subject: [PATCH] Update. --- grids.py | 38 ++++++++++++------ main.py | 23 ++++++----- quiz_machine.py | 103 ++++++++++++++++++++++++++---------------------- 3 files changed, 93 insertions(+), 71 deletions(-) diff --git a/grids.py b/grids.py index 37ed6a0..5ddcf32 100755 --- a/grids.py +++ b/grids.py @@ -131,7 +131,8 @@ class Grids(problem.Problem): def get_structure(self, quizzes): S = self.height * self.width struct = tuple( - self.tok2l[n.item()] for n in quizzes.reshape(-1, 4, S + 1)[0, :, 0] + self.tok2l[n.item()] + for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0] ) self.check_structure(quizzes, struct) return struct @@ -143,8 +144,8 @@ class Grids(problem.Problem): sf = dict((l, n) for n, l in enumerate(struct_from)) result = quizzes.new(quizzes.size()) - q = quizzes.reshape(-1, 4, S + 1) - r = result.reshape(-1, 4, S + 1) + q = quizzes.reshape(quizzes.size(0), 4, S + 1) + r = result.reshape(result.size(0), 4, S + 1) r[:, 0] = q[:, sf[struct[0]], :] r[:, 1] = q[:, sf[struct[1]], :] @@ -153,12 +154,21 @@ class Grids(problem.Problem): return result + def non_trivial(self, quizzes): + S = self.height * self.width + assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B")) + a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:] + return (a[:, 0] == a[:, 1]).min(dim=1).values & (a[:, 2] == a[:, 3]).min( + dim=1 + ).values + def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)): - assert check_structure(quizzes, struct) + assert self.check_structure(quizzes, struct) ar_mask = quizzes.new_zeros(quizzes.size()) - a = ar_mask.reshape(-1, 4, -1)[:, :, 1:] + S = self.height * self.width + a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:] a[:, 0, :] = mask[0] a[:, 1, :] = mask[1] a[:, 2, :] = mask[2] @@ -168,7 +178,7 @@ class Grids(problem.Problem): def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")): S = self.height * self.width - q = quizzes.reshape(-1, 4, S + 1) + q = quizzes.reshape(quizzes.size(0), 4, S + 1) return ( (q[:, 0, 0] == self.l2tok[struct[0]]) & (q[:, 1, 0] == self.l2tok[struct[1]]) @@ -286,11 +296,13 @@ class Grids(problem.Problem): nrow=4, margin=8, ): + quizzes = quizzes.to("cpu") + S = self.height * self.width A, f_A, B, f_B = ( - quizzes.reshape(-1, 4, S + 1)[:, :, 1:] - .reshape(-1, 4, self.height, self.width) + quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:] + .reshape(quizzes.size(0), 4, self.height, self.width) .permute(1, 0, 2, 3) ) @@ -1382,14 +1394,16 @@ class Grids(problem.Problem): def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")): S = self.height * self.width quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64) - quizzes[:, 0 * (S + 1)] = self.l2tok(struct[0]) - quizzes[:, 1 * (S + 1)] = self.l2tok(struct[1]) - quizzes[:, 2 * (S + 1)] = self.l2tok(struct[2]) - quizzes[:, 3 * (S + 1)] = self.l2tok(struct[3]) + quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]] + quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]] + quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]] + quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]] return quizzes def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False): + S = self.height * self.width + if tasks is None: tasks = self.all_tasks diff --git a/main.py b/main.py index fcca116..deba848 100755 --- a/main.py +++ b/main.py @@ -324,7 +324,7 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### -def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device): +def run_tests(model, quiz_machine, local_device=main_device): with torch.autograd.no_grad(): model.eval().to(local_device) @@ -355,7 +355,6 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de model=model, input=full_input[:2000], result_dir=args.result_dir, - deterministic_synthesis=deterministic_synthesis, ) @@ -408,7 +407,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}") - run_tests(model, quiz_machine, deterministic_synthesis=False) + run_tests(model, quiz_machine) threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values threshold = threshold[threshold.size(0) // 2] @@ -455,7 +454,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # We discard the trivial ones, according to a criterion # specific to the world quizzes (e.g. B=f(B)) - c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] + c_quizzes = c_quizzes[quiz_machine.problem.non_trivial(c_quizzes)] # We go through nb_rounds rounds and keep only quizzes on # which @@ -471,6 +470,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 remains = [c_quizzes.size(0)] for r in range(args.nb_rounds): + if c_quizzes.size(0) == 0: + break + number_correct_responses += quiz_machine.models_successes(models, c_quizzes) nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1) @@ -487,9 +489,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 remains.append(c_quizzes.size(0)) - if c_quizzes.size(0) == 0: - break - if c_quizzes.size(0) > 0: nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0) recorded_validated.append(c_quizzes) @@ -550,9 +549,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 v = " ".join([str(n.item()) for n in r]) f.write(f"{n}: {v}\n") - quiz_machine.save_quiz_illustrations( - args.result_dir, prefix, vq, show_part_to_predict=False - ) + quiz_machine.save_quizzes_as_image(args.result_dir, prefix, vq) ###################################################################### @@ -727,8 +724,10 @@ for n_epoch in range(current_epoch, args.nb_epochs): temperature_cold=args.temperature_cold, ) - quiz_machine.save_quiz_illustrations( - args.result_dir, f"non_validated_{n_epoch:04d}_{model.id:02d}", c_quizzes + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"non_validated_{n_epoch:04d}_{model.id:02d}.png", + c_quizzes, ) # Renew the training samples diff --git a/quiz_machine.py b/quiz_machine.py index bb62181..2fb196c 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -32,6 +32,9 @@ def one_batch_masked_inplace_autoregression( logit_transformer=None, deterministic_synthesis=False, ): + if input.size(0) == 0: + return + to_generate = (ar_mask.sum(0) > 0).nonzero() if to_generate.min() > 0: @@ -174,11 +177,11 @@ class QuizMachine: ###################################################################### - def predict(self, input, struct, mask): + def predict(self, model, quizzes, struct, mask): ar_mask = self.problem.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) result = quizzes * (1 - ar_mask) - seq_logproba = torch.empty(fwd_quizzes, device=self.device) + seq_logproba = torch.empty(quizzes.size(0), device=self.device) masked_inplace_autoregression( model=model, @@ -186,50 +189,47 @@ class QuizMachine: input=result, ar_mask=ar_mask, seq_logproba=seq_logproba, - deterministic_synthesis=deterministic_synthesis, + deterministic_synthesis=False, progress_bar_desc="accuracy", device=self.device, ) - nb_correct = (result == quizzes).min(dim=1).long() + correct = (result == quizzes).min(dim=1).values return result, correct def produce_results( - self, n_epoch, model, input, result_dir, deterministic_synthesis + self, + n_epoch, + model, + input, + result_dir, ): input = input.to(self.device) - i = self.problem.indices_select(quizzes=input, struct=struct) - - input_fwd = input[i] - test_result_fwd, test_correct_fwd = predict( - input_fwd, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) - ) - - input_bck = self.problem.reconfigure( - predict(input[i == False], ("f_B", "f_A", "B", "A"), (0, 1, 1, 1))[0], - struct=("A", "f_A", "B", "f_B"), - ) - - l = input_bck.size(1) // 4 - input_bck[:, 3 * l :] = input[i == False][:, :l] + result = input.new(input.size()) + correct = torch.empty(input.size(0), device=input.device, dtype=torch.bool) + + nb = 0 + for struct, mask in [ + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), + (("f_B", "f_A", "B", "A"), (0, 1, 1, 1)), + ]: + i = self.problem.indices_select(quizzes=input, struct=struct) + nb += i.long().sum() + result[i], correct[i] = self.predict( + model=model, quizzes=input[i], struct=struct, mask=mask + ) - test_result_bck, test_correct_bck = predict( - input_bck, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) - ) + assert nb == input.size(0) - main_test_accuracy = test_correct.sum() / test_correct.size(0) + main_test_accuracy = correct.sum() / correct.size(0) ############################## - test_result = torch.cat([test_result_fwd[:64], test_result_bck[:64]], dim=0) - test_correct = torch.cat([test_correct_fwd[:64], test_correct_bck[:64]], dim=0) - - self.save_quiz_illustrations( + self.problem.save_quizzes_as_image( result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - quizzes=test_result, - # mistakes=test_correct, + f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", + quizzes=result[:128], ) return main_test_accuracy @@ -355,12 +355,19 @@ class QuizMachine: seq_logproba[...] = 0.0 + c_quizzes = c_quizzes.to(self.device) + print(self.problem.get_structure(c_quizzes)) + reversed_c_quizzes = self.problem.reconfigure( + c_quizzes, ("f_A", "A", "f_B", "B") + ) + for model in models_for_validation: # A, f(A), B | f(B) - c_quizzes = c_quizzes.to(self.device) result = c_quizzes.clone() - ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3") + ar_mask = self.problem.make_ar_mask( + result, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1) + ) masked_inplace_autoregression( model=model, @@ -377,10 +384,11 @@ class QuizMachine: # ------------------------------- # f(A), A, f(B) | B - c_quizzes = self.problem.flip(c_quizzes, pairwise_flip=True).to(self.device) - result = c_quizzes.clone() + result = reversed_c_quizzes.clone() - ar_mask = self.problem.make_ar_mask(result, shape="fwd_3_bck_3") + ar_mask = self.problem.make_ar_mask( + result, ("f_A", "A", "f_B", "B"), mask=(0, 0, 0, 1) + ) masked_inplace_autoregression( model=model, @@ -392,7 +400,7 @@ class QuizMachine: device=self.device, ) - correct *= (c_quizzes == result).long().min(dim=-1).values + correct *= (reversed_c_quizzes == result).long().min(dim=-1).values # ------------------------------- @@ -409,11 +417,8 @@ class QuizMachine: temperature_hot=1.0, temperature_cold=1.0, ): - c_quizzes = torch.empty( - nb, - self.problem.seq_len, - device=self.device, - dtype=torch.int64, + c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")).to( + self.device ) seq_logproba = torch.zeros(nb, device=self.device) @@ -426,13 +431,13 @@ class QuizMachine: # ) # lt_clean = None - c_quizzes[...] = self.problem.token_backward - masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), + ar_mask=self.problem.make_ar_mask( + c_quizzes, ("f_B", "f_A", "A", "B"), (1, 0, 0, 0) + ), seq_logproba=seq_logproba, logit_transformer=lt_noisy, deterministic_synthesis=False, @@ -443,20 +448,24 @@ class QuizMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + ar_mask=self.problem.make_ar_mask( + c_quizzes, ("f_B", "f_A", "A", "B"), (0, 1, 1, 1) + ), seq_logproba=seq_logproba, logit_transformer=lt_clean, deterministic_synthesis=False, device=self.device, ) - c_quizzes = self.problem.p_a_flip(c_quizzes) + c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) masked_inplace_autoregression( model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + ar_mask=self.problem.make_ar_mask( + c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ), seq_logproba=seq_logproba, logit_transformer=lt_clean, deterministic_synthesis=False, -- 2.39.5