From c1c4ba37480db7829a8b443340484a237bbd01fc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 23 Jul 2024 00:05:45 +0200 Subject: [PATCH] Update. --- grids.py | 24 ++++-- main.py | 66 ++++++++++----- mygpt.py | 17 ++++ quiz_machine.py | 214 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 294 insertions(+), 27 deletions(-) diff --git a/grids.py b/grids.py index 22704b2..b531eb9 100755 --- a/grids.py +++ b/grids.py @@ -176,21 +176,31 @@ class Grids(problem.Problem): dim=1, ) else: - flipped = torch.cat( + flipped_from_forward = torch.cat( [ quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1], - quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1], + quizzes[:, 0 * (S + 1) : 2 * (S + 1) + S + 1], quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1], + quizzes[:, 2 * (S + 1) : 0 * (S + 1) + S + 1], + ], + dim=1, + ) + flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward + + flipped_from_backward = torch.cat( + [ + quizzes[:, 1 * (S + 1) : 3 * (S + 1) + S + 1], + quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1], + quizzes[:, 3 * (S + 1) : 1 * (S + 1) + S + 1], quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1], ], dim=1, ) + flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward + + m = (flipped[:, 0] == self.token_forward).long() - m = (flipped[:, 0] == self.token_forward).long() - flipped[:, 0 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward - flipped[:, 1 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward - flipped[:, 2 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward - flipped[:, 3 * (S + 1)] = m * self.token_backward + (1 - m) * self.token_forward + flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward return flipped diff --git a/main.py b/main.py index f8f8502..a540cc0 100755 --- a/main.py +++ b/main.py @@ -87,6 +87,8 @@ parser.add_argument("--gpus", type=str, default="all") parser.add_argument("--nb_gpts", type=int, default=5) +parser.add_argument("--max_fail_to_validate", type=int, default=1) + parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) parser.add_argument("--proba_understands", type=float, default=0.9) @@ -99,6 +101,8 @@ parser.add_argument("--temperature_cold", type=float, default=0.75) parser.add_argument("--nb_rounds", type=int, default=3) +parser.add_argument("--noise_level", type=float, default=0) + parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") parser.add_argument("--p2a_only", action="store_true", default=False) @@ -374,9 +378,21 @@ def one_epoch(model, quiz_machine, local_device=main_device): if nb_train_samples % args.batch_size == 0: optimizer.zero_grad() + targets = input + + if args.noise_level > 0: + m = ( + (torch.rand(targets.size(), device=targets.device) < args.noise_level) + & (targets != quiz_machine.problem.token_forward) + & (targets != quiz_machine.problem.token_backward) + ).long() + input = (1 - m) * input.clone() + m * torch.randint( + vocabulary_size, input.size(), device=input.device + ) + output = model(mygpt.BracketedSequence(input)).x loss_per_token = F.cross_entropy( - output.transpose(1, 2), input, reduction="none" + output.transpose(1, 2), targets, reduction="none" ) loss = loss_per_token.mean() acc_train_loss += loss.item() * input.size(0) @@ -421,7 +437,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_validated = 0 recorded_validated = [] - # recorded_too_simple = [] start_time = time.perf_counter() @@ -450,26 +465,33 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] # We go through nb_rounds rounds and keep only quizzes on - # which models respond always the same through rounds + # which models respond always the same through rounds and one + # which N-1 succeed and one fails + + ms = 0 # "model scores" - total_nb_validated = 0 - ms = 0 for r in range(args.nb_rounds): ms += quiz_machine.models_successes(models, c_quizzes) - # print(f"{r=} {ms=}") - i = ((ms == r + 1).long().sum(dim=1) == ms.size(1) - 1) & ( - (ms == 0).long().sum(dim=1) == 1 + nb_sure_and_correct = (ms == r + 1).long().sum(dim=1) + nb_sure_and_fail = (ms == 0).long().sum(dim=1) + to_keep = ( + (nb_sure_and_correct + nb_sure_and_fail == ms.size(1)) + & (nb_sure_and_fail >= 1) + & (nb_sure_and_fail <= args.max_fail_to_validate) ) - c_quizzes = c_quizzes[i] - ms = ms[i] + + c_quizzes = c_quizzes[to_keep] + ms = ms[to_keep] + print(f"Round {r} remains {c_quizzes.size(0)}") if c_quizzes.size(0) == 0: break if c_quizzes.size(0) > 0: nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0) - total_nb_validated = nb_validated_per_model.sum().item() recorded_validated.append(c_quizzes) + total_nb_validated = nb_validated_per_model.sum().item() + duration = time.perf_counter() - start_time if total_nb_validated > 0: @@ -492,7 +514,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ) validated_quizzes = torch.cat(recorded_validated, dim=0) - # too_simple_quizzes = torch.cat(recorded_too_simple, dim=0) ###################################################################### # store the new c_quizzes which have been validated @@ -516,14 +537,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 args.result_dir, prefix, vq, show_part_to_predict=False ) - # vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]] - - # if vq.size(0) > 0: - # prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple" - # quiz_machine.save_quiz_illustrations( - # args.result_dir, prefix, vq, show_part_to_predict=False - # ) - ###################################################################### @@ -696,6 +709,19 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) log_string(f"wrote {filename}") + for model in weakest_models: + c_quizzes = quiz_machine.generate_c_quizzes( + 128, + model_for_generation=model, + p2a_only=args.p2a_only, + temperature_hot=args.temperature_hot, + temperature_cold=args.temperature_cold, + ) + + quiz_machine.save_quiz_illustrations( + args.result_dir, f"non_validated_{n_epoch:04d}_{model.id:02d}", c_quizzes + ) + # Renew the training samples for model in weakest_models: diff --git a/mygpt.py b/mygpt.py index d0fda7e..51c0862 100755 --- a/mygpt.py +++ b/mygpt.py @@ -295,6 +295,23 @@ class MyGPT(nn.Module): bs = self.readout(bs) return bs + def partial_forward(self, bs, start_layer=None, end_layer=None): + if start_layer is None: + # print(f"GENERATE {bs.first} {bs.first+bs.nb}") + bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) + bs = self.embedding(bs) + if end_layer is not None: + return self.trunk[:end_layer](bs) + else: + bs = self.trunk(bs) + bs = self.readout(bs) + return bs + else: + bs = self.trunk[start_layer:](bs) + bs = self.trunk(bs) + bs = self.readout(bs) + return bs + def record_attention(self, v=True): for m in self.modules(): if isinstance(m, QKVAttention): diff --git a/quiz_machine.py b/quiz_machine.py index a5f9a89..182e9ff 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -601,3 +601,217 @@ class QuizMachine: ) return c_quizzes.to("cpu") + + ###################################################################### + + def generate_c_quizzes_fixed_point( + self, + nb, + model_for_generation, + p2a_only=False, + temperature_hot=1.0, + temperature_cold=1.0, + ): + c_quizzes = torch.empty( + nb, + self.prompt_len + self.answer_len, + device=self.device, + dtype=torch.int64, + ) + + seq_logproba = torch.zeros(nb, device=self.device) + + lt_noisy = lambda s, logits: logits / temperature_hot + lt_clean = lambda s, logits: logits / temperature_cold + + c_quizzes[...] = self.problem.token_backward + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_before", c_quizzes) + + c_quizzes = self.problem.p_a_flip(c_quizzes) + + while True: + print("ITERATION") + + c_quizzes = self.problem.p_a_flip(c_quizzes) + + pred = c_quizzes.clone() + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, + deterministic_synthesis=False, + device=self.device, + ) + + c_quizzes = self.problem.p_a_flip(c_quizzes) + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, + deterministic_synthesis=False, + device=self.device, + ) + + if pred[202:].equal(c_quizzes[202:]): + break + + self.save_quiz_illustrations("/tmp", f"c_quizzes_after", c_quizzes) + + exit(0) + + return c_quizzes.to("cpu") + + ###################################################################### + + def generate_c_quizzes_mixing( + self, + nb, + model_for_generation, + p2a_only=False, + temperature_hot=1.0, + temperature_cold=1.0, + ): + c_quizzes = torch.empty( + nb, + self.prompt_len + self.answer_len, + device=self.device, + dtype=torch.int64, + ) + + c_quizzes_1 = torch.empty( + nb, + self.prompt_len + self.answer_len, + device=self.device, + dtype=torch.int64, + ) + + c_quizzes_2 = torch.empty( + nb, + self.prompt_len + self.answer_len, + device=self.device, + dtype=torch.int64, + ) + + seq_logproba = torch.zeros(nb, device=self.device) + + lt_noisy = lambda s, logits: logits / temperature_hot + lt_clean = lambda s, logits: logits / temperature_cold + + ###################################################################### + + c_quizzes_1[...] = self.problem.token_backward + ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0") + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes_1, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1) + + c_quizzes_2[...] = self.problem.token_backward + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes_2, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + logit_transformer=lt_noisy, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2) + + h = len(model_for_generation.trunk) // 2 + + with torch.autograd.no_grad(): + t = model_for_generation.training + model_for_generation.eval() + + bs1 = model_for_generation.partial_forward( + mygpt.BracketedSequence(c_quizzes_1), end_layer=h + ) + bs2 = model_for_generation.partial_forward( + mygpt.BracketedSequence(c_quizzes_2), end_layer=h + ) + + alpha = 0.1 + + output = model_for_generation.partial_forward( + mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x), + start_layer=h, + ).x + + dist = torch.distributions.categorical.Categorical(logits=output) + c_quizzes[...] = dist.sample() + + c_quizzes[...] = ( + ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward + ) + + model_for_generation.train(t) + + self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes) + + ###################################################################### + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes) + + c_quizzes = self.problem.p_a_flip(c_quizzes) + + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), + seq_logproba=seq_logproba, + logit_transformer=lt_clean, + deterministic_synthesis=False, + device=self.device, + ) + + self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes) + + print("DONE") + exit(0) + + return c_quizzes.to("cpu") -- 2.39.5