From b7d4cb766c0fbb0f054465782c3761cf33a74896 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 23 Jul 2024 15:41:13 +0200 Subject: [PATCH] Update. --- grids.py | 1 + main.py | 19 ++++++- quiz_machine.py | 133 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 2 deletions(-) diff --git a/grids.py b/grids.py index 5778f85..406c0b7 100755 --- a/grids.py +++ b/grids.py @@ -1413,6 +1413,7 @@ class Grids(problem.Problem): m = (d < self.height * self.width).long() X[i, j] = c[-1] f_X[...] = m * c[-1] + (1 - m) * f_X + f_X[i, j] = 0 if accept_full or (d * (X == 0)).max() == self.height * self.width: break diff --git a/main.py b/main.py index 61820dd..7feb3b9 100755 --- a/main.py +++ b/main.py @@ -441,7 +441,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): nb_to_validate = nb_for_train + nb_for_test - nb_to_generate_per_iteration = nb_to_validate + nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate) nb_validated = 0 recorded_validated = [] @@ -485,6 +485,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # This is nb_quizzes x nb_models number_correct_responses = 0 + remains = [c_quizzes.size(0)] + for r in range(args.nb_rounds): number_correct_responses += quiz_machine.models_successes(models, c_quizzes) @@ -500,7 +502,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = c_quizzes[to_keep] number_correct_responses = number_correct_responses[to_keep] - log_string(f"round {r} remains {c_quizzes.size(0)}") + remains.append(c_quizzes.size(0)) if c_quizzes.size(0) == 0: break @@ -528,6 +530,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 else: e = "???" + v = " ".join([x.item() for x in remains]) + log_string(f"filter c_quizzes {v}") + log_string( f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)" ) @@ -552,6 +557,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 if vq.size(0) > 0: prefix = f"culture_c_quiz_{n_epoch:04d}" + + number_correct_responses = 0 + for r in range(args.nb_rounds): + number_correct_responses += quiz_machine.models_successes(models, vq) + + with open(os.path.join(args.result_dir, prefix + "_responses.dat"), "w") as f: + for n, r in enumerate(number_correct_responses): + v = " ".join([str(n.item()) for n in r]) + f.write(f"{n}: {v}\n") + quiz_machine.save_quiz_illustrations( args.result_dir, prefix, vq, show_part_to_predict=False ) diff --git a/quiz_machine.py b/quiz_machine.py index d6c686e..4b07de3 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -605,3 +605,136 @@ class QuizMachine: return c_quizzes.to("cpu") ###################################################################### + + def generate_c_quizzes_mixing( + self, + nb, + model_for_generation, + p2a_only=False, + temperature_hot=1.0, + temperature_cold=1.0, + ): + c_quizzes = torch.empty( + nb, + self.prompt_len + self.answer_len, + device=self.device, + dtype=torch.int64, + ) + + c_quizzes_1 = torch.empty( + nb, + self.prompt_len + self.answer_len, + device=self.device, + dtype=torch.int64, + ) + + c_quizzes_2 = torch.empty( + nb, + self.prompt_len + self.answer_len, + device=self.device, + dtype=torch.int64, + ) + + seq_logproba = torch.zeros(nb, device=self.device) + + lt_noisy = lambda s, logits: logits / temperature_hot + lt_clean = lambda s, logits: logits / temperature_cold + + ###################################################################### + + c_quizzes_1[...] = self.problem.token_backward + ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0") + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes_1, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1) + + c_quizzes_2[...] = self.problem.token_backward + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes_2, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2) + + h = len(model_for_generation.trunk) // 2 + + with torch.autograd.no_grad(): + t = model_for_generation.training + model_for_generation.eval() + + bs1 = model_for_generation.partial_forward( + mygpt.BracketedSequence(c_quizzes_1), end_layer=h + ) + bs2 = model_for_generation.partial_forward( + mygpt.BracketedSequence(c_quizzes_2), end_layer=h + ) + + alpha = 0.5 + + output = model_for_generation.partial_forward( + mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x), + start_layer=h, + ).x + + dist = torch.distributions.categorical.Categorical(logits=output) + c_quizzes[...] = dist.sample() + + c_quizzes[...] = ( + ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward + ) + + model_for_generation.train(t) + + self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes) + + ###################################################################### + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes) + + c_quizzes = self.problem.p_a_flip(c_quizzes) + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes) + + print("DONE") + exit(0) + + return c_quizzes.to("cpu") -- 2.20.1