From e9906eac2ee163b23163f7a0d1d144cf7df42306 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 22 Jul 2024 06:36:56 +0200
Subject: [PATCH] Update.

---
 grids.py        |  32 +++++++-------
 main.py         | 114 ++++++++++++++++++++----------------------
 quiz_machine.py |  47 +++++++++-----------
 3 files changed, 84 insertions(+), 109 deletions(-)

diff --git a/grids.py b/grids.py
index d3e7dcc..22704b2 100755
--- a/grids.py
+++ b/grids.py
@@ -716,9 +716,9 @@ class Grids(problem.Problem):
         while True:
             error = False
 
-            N = torch.randint(5, (1,)).item() + 2
-            c = torch.zeros(N + 1, dtype=torch.int64)
-            c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+            N = 3
+            c = torch.zeros(N + 2, dtype=torch.int64)
+            c[1:] = torch.randperm(len(self.colors) - 1)[: N + 1] + 1
 
             for X, f_X in [(A, f_A), (B, f_B)]:
                 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
@@ -728,7 +728,7 @@ class Grids(problem.Problem):
                         self.height,
                         self.width,
                         nb_seeds=self.height * self.width // 8,
-                        nb_iterations=self.height * self.width // 10,
+                        nb_iterations=self.height * self.width // 5,
                     )
                 )
 
@@ -739,20 +739,20 @@ class Grids(problem.Problem):
                 # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
                 #     c.size(0) - 1
                 # ) + 1
-                V = torch.randint(c.size(0) - 1, (X.max() + 1,)) + 1
+
+                V = torch.randint(N, (X.max() + 1,)) + 1
                 V[0] = 0
                 NB = F.one_hot(c[V]).sum(dim=0)
                 X[...] = c[V[X]]
-
-                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
-                    f_X[...] = 0
-                    for e in range(1, N + 1):
-                        for j in range(NB[c[e]]):
-                            if j < self.width:
-                                f_X[e - 1, j] = c[e]
-                            else:
-                                error = True
-                                break
+                f_X[...] = X
+
+                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
+                    m = NB[c[:-1]].max()
+                    if (NB[c[:-1]] == m).long().sum() == 1:
+                        for e in range(1, N + 1):
+                            if NB[c[e]] == m:
+                                a = (f_X == c[e]).long()
+                                f_X[...] = (1 - a) * f_X + a * c[-1]
                 else:
                     error = True
                     break
@@ -1423,7 +1423,7 @@ if __name__ == "__main__":
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_contact]:
+    for t in [grids.task_count]:
     # for t in [grids.task_symbols]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
diff --git a/main.py b/main.py
index b7d0431..f8f8502 100755
--- a/main.py
+++ b/main.py
@@ -97,6 +97,8 @@ parser.add_argument("--temperature_hot", type=float, default=1.5)
 
 parser.add_argument("--temperature_cold", type=float, default=0.75)
 
+parser.add_argument("--nb_rounds", type=int, default=3)
+
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
 parser.add_argument("--p2a_only", action="store_true", default=False)
@@ -413,57 +415,27 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 ######################################################################
 
 
-def keep_good_quizzes(models, quizzes, required_nb_failures=1):
-    quizzes = quizzes[quiz_machine.non_trivial(quizzes)]
-
-    if args.c_quiz_validation_mode == "proba":
-        token_logprobas = quiz_machine.solution_token_logprobas(models, quizzes)
-        l = token_logprobas.sum(dim=-1).sort(dim=-1).values
-
-        to_keep = (l[:, 0] < math.log(args.proba_not_understands)) & (
-            l[:, 1] > math.log(args.proba_understands)
-        )
-
-    elif args.c_quiz_validation_mode == "predict":
-        nc = quiz_machine.solution_nb_correct(models, quizzes)
-
-        count_nc = tuple(
-            n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
-        )
-
-        log_string(f"nb_correct {count_nc}")
-
-        to_keep = nc == (len(models) - required_nb_failures)
-
-    else:
-        raise ValueError(f"{args.c_quiz_validation_mode=}")
-
-    if args.dirty_debug:
-        # warnings.warn("DEBUG", RuntimeWarning)
-        to_keep = torch.rand(to_keep.size(), device=to_keep.device) < 0.5
-
-    return quizzes[to_keep]
-
-
-######################################################################
-
-
 def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
-    nb_to_create = nb_for_train + nb_for_test
-    nb_to_generate_per_iteration = nb_to_create
+    nb_to_validate = nb_for_train + nb_for_test
+    nb_to_generate_per_iteration = nb_to_validate
     nb_validated = 0
 
     recorded_validated = []
-    recorded_too_simple = []
+    # recorded_too_simple = []
 
     start_time = time.perf_counter()
 
-    nb_validated = torch.zeros(len(models), dtype=torch.int64)
+    nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
 
-    while nb_validated.sum() < nb_to_create:
+    while nb_validated_per_model.sum() < nb_to_validate:
         # We balance the number of quizzes per model
 
-        model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0]
+        model_for_generation = sorted(
+            models, key=lambda m: nb_validated_per_model[m.id]
+        )[0]
+
+        # We generate quizzes with a procedure that injects some
+        # structured noise
 
         c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
@@ -473,30 +445,40 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             temperature_cold=args.temperature_cold,
         )
 
-        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
-
-        nc = quiz_machine.solution_nb_correct(models, c_quizzes)
-
-        count_nc = tuple(
-            n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0)
-        )
-
-        log_string(f"nb_correct {count_nc}")
+        # We discard the trivial ones
 
-        recorded_too_simple.append(c_quizzes[nc == len(models)])
+        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
-        c_quizzes = c_quizzes[nc == len(models) - 1]
+        # We run nb_rounds rounds and keep only the quizzes that all
+        # models but one answer correctly, and that one fails, every round
 
-        nb_validated[model_for_generation.id] += c_quizzes.size(0)
-        total_nb_validated = nb_validated.sum().item()
+        total_nb_validated = 0
+        ms = 0
+        for r in range(args.nb_rounds):
+            ms += quiz_machine.models_successes(models, c_quizzes)
+            # print(f"{r=} {ms=}")
+            i = ((ms == r + 1).long().sum(dim=1) == ms.size(1) - 1) & (
+                (ms == 0).long().sum(dim=1) == 1
+            )
+            c_quizzes = c_quizzes[i]
+            ms = ms[i]
+            if c_quizzes.size(0) == 0:
+                break
 
-        recorded_validated.append(c_quizzes)
+        if c_quizzes.size(0) > 0:
+            nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
+            total_nb_validated = nb_validated_per_model.sum().item()
+            recorded_validated.append(c_quizzes)
 
         duration = time.perf_counter() - start_time
 
         if total_nb_validated > 0:
-            if total_nb_validated < nb_to_create:
-                d = (nb_to_create - total_nb_validated) * duration / total_nb_validated
+            if total_nb_validated < nb_to_validate:
+                d = (
+                    (nb_to_validate - total_nb_validated)
+                    * duration
+                    / total_nb_validated
+                )
                 e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
                     "%a %H:%M"
                 )
@@ -506,11 +488,11 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             e = "???"
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
         )
 
     validated_quizzes = torch.cat(recorded_validated, dim=0)
-    too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
+    # too_simple_quizzes = torch.cat(recorded_too_simple, dim=0)
 
     ######################################################################
     # store the new c_quizzes which have been validated
@@ -519,7 +501,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     quiz_machine.store_c_quizzes(v_train, for_train=True)
     quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True)
 
-    v_test = validated_quizzes[nb_for_train:nb_to_create]
+    v_test = validated_quizzes[nb_for_train:nb_to_validate]
     quiz_machine.store_c_quizzes(v_test, for_train=False)
     quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False)
 
@@ -534,13 +516,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             args.result_dir, prefix, vq, show_part_to_predict=False
         )
 
-    vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]]
+    # vq = too_simple_quizzes[torch.randperm(too_simple_quizzes.size(0))[:128]]
 
-    if vq.size(0) > 0:
-        prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
-        quiz_machine.save_quiz_illustrations(
-            args.result_dir, prefix, vq, show_part_to_predict=False
-        )
+    # if vq.size(0) > 0:
+    #     prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple"
+    #     quiz_machine.save_quiz_illustrations(
+    #         args.result_dir, prefix, vq, show_part_to_predict=False
+    #     )
 
 
 ######################################################################
diff --git a/quiz_machine.py b/quiz_machine.py
index b965e33..5f14528 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -59,8 +59,8 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature,
-    deterministic_synthesis,
+    logit_transformer=None,
+    deterministic_synthesis=False,
 ):
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
@@ -73,7 +73,7 @@ def one_batch_masked_inplace_autoregression(
 
         logits = output[:, s]
 
-        logits = (logits / temperature).log_softmax(dim=-1)
+        logits = logit_transformer(s, logits).log_softmax(dim=-1)
 
         if deterministic_synthesis:
             t_next = logits.argmax(-1)
@@ -94,8 +94,8 @@ def masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature,
-    deterministic_synthesis,
+    logit_transformer=None,
+    deterministic_synthesis=False,
     forbidden_tokens=None,
     logit_biases=None,
     progress_bar_desc=None,
@@ -127,7 +127,7 @@ def masked_inplace_autoregression(
                 input=input,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba,
-                temperature=temperature,
+                logit_transformer=logit_transformer,
                 deterministic_synthesis=deterministic_synthesis,
             )
 
@@ -305,7 +305,6 @@ class QuizMachine:
             input=result,
             ar_mask=ar_mask,
             seq_logproba=seq_logproba,
-            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
             progress_bar_desc="accuracy",
             device=self.device,
@@ -447,15 +446,14 @@ class QuizMachine:
 
     ###############################################################
 
-    def solution_nb_correct(self, models_for_validation, c_quizzes):
+    def models_successes(self, models_for_validation, c_quizzes):
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
             device=self.device,
         )
 
-        nb_correct = 0
-        correct_models = torch.empty(
+        correctly_solved = torch.empty(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
             device=self.device,
@@ -477,14 +475,11 @@ class QuizMachine:
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
-                temperature=1.0,
                 deterministic_synthesis=False,
                 device=self.device,
             )
 
-            correct_models[:, model.id] = (
-                (c_quizzes == result).long().min(dim=-1).values
-            )
+            correct = (c_quizzes == result).long().min(dim=-1).values
 
             # -------------------------------
 
@@ -502,22 +497,17 @@ class QuizMachine:
                 input=result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
-                temperature=1.0,
                 deterministic_synthesis=False,
                 device=self.device,
             )
 
-            correct_models[:, model.id] *= (
-                (c_quizzes == result).long().min(dim=-1).values
-            )
+            correct *= (c_quizzes == result).long().min(dim=-1).values
 
             # -------------------------------
 
-        i = correct_models.sum(dim=1) == correct_models.size(1) - 1
-        c = (correct_models[i] == 0).long().sum(dim=0)
-        self.logger(f"nb_failures_on_validated {tuple(x.item() for x in c)}")
+            correctly_solved[:, model.id] = correct
 
-        return correct_models.sum(dim=1).to("cpu")
+        return correctly_solved.to("cpu")
 
     ###############################################################
 
@@ -538,6 +528,9 @@ class QuizMachine:
 
         seq_logproba = torch.zeros(nb, device=self.device)
 
+        def heater(T):
+            return lambda s, logits: logits / T
+
         if p2a_only:
             c_quizzes[...] = self.problem.token_forward
 
@@ -547,7 +540,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_hot,
+                logit_transformer=heater(temperature_hot),
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -558,7 +551,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_cold,
+                logit_transformer=heater(temperature_cold),
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -572,7 +565,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_hot,
+                logit_transformer=heater(temperature_hot),
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -583,7 +576,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_cold,
+                logit_transformer=heater(temperature_cold),
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -596,7 +589,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                temperature=temperature_cold,
+                logit_transformer=heater(temperature_cold),
                 deterministic_synthesis=False,
                 device=self.device,
             )
-- 
2.20.1
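
Note on the reworked task_count answer in grids.py: the added block keeps the input grid and, when it contains at least three distinct values and a unique most frequent colour, repaints the cells of that colour with the reserve colour c[-1] through a 0/1 mask. The function below is a hypothetical, self-contained distillation of that repaint idiom; it measures frequency by cell count for brevity, whereas the hunk above reads the counts from NB, i.e. from how many original cell values were mapped to each colour.

    import torch

    def repaint_most_frequent(X, new_color):
        # Count the occurrences of each value, ignoring the background colour 0.
        counts = torch.bincount(X.flatten())
        counts[0] = 0
        m = counts.max()
        # Require a unique most frequent colour, as the patched task does with
        # (NB[c[:-1]] == m).long().sum() == 1; otherwise the generator retries.
        if (counts == m).long().sum() != 1:
            return None
        winner = counts.argmax()
        # Same masked repaint as in the hunk:
        # a = (f_X == c[e]).long(); f_X[...] = (1 - a) * f_X + a * c[-1]
        a = (X == winner).long()
        return (1 - a) * X + a * new_color

    X = torch.tensor([[0, 1, 1], [2, 2, 2], [0, 3, 0]])
    print(repaint_most_frequent(X, new_color=4))  # the 2s become 4s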
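Note on the nb_rounds filtering loop in record_new_c_quizzes: summing the 0/1 success matrices returned by models_successes over the rounds and requiring, after round r, that all models but one have r + 1 successes while exactly one has none amounts to keeping the quizzes that n - 1 models solve in every round and that the remaining model fails in every round. The sketch below applies that final criterion in one pass on hypothetical inputs; the incremental version in the patch additionally prunes the batch after each round, so later models_successes calls run on fewer quizzes.

    import torch

    def keep_consistently_hard(success_per_round):
        # success_per_round: one (nb_quizzes, nb_models) 0/1 tensor per round.
        ms = torch.stack(success_per_round, dim=0).sum(dim=0)
        nb_rounds, nb_models = len(success_per_round), ms.size(1)
        solved_by_all_but_one = (ms == nb_rounds).long().sum(dim=1) == nb_models - 1
        failed_by_exactly_one = (ms == 0).long().sum(dim=1) == 1
        return solved_by_all_but_one & failed_by_exactly_one

    # Toy check with three models and two rounds: only the first quiz qualifies.
    r1 = torch.tensor([[1, 1, 0], [1, 1, 1]])
    r2 = torch.tensor([[1, 1, 0], [1, 0, 1]])
    print(keep_consistently_hard([r1, r2]))  # tensor([ True, False])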
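Note on the temperature to logit_transformer change in quiz_machine.py: the sampling loop now applies an arbitrary callable (step, logits) -> logits instead of dividing by a scalar, and the heater(T) helper added in generate_c_quizzes reproduces plain temperature scaling. The sketch below illustrates the pattern; the None guard and the ramp variant are assumptions made for illustration, not part of the patch, which as shown applies the transformer unconditionally.

    import torch

    def heater(T):
        # Plain temperature scaling, as added in generate_c_quizzes().
        return lambda s, logits: logits / T

    def ramp(T_hot, T_cold, s_switch):
        # Hypothetical variant that the callable form makes possible: sample
        # the first s_switch positions hot and the remaining ones cold.
        return lambda s, logits: logits / (T_hot if s < s_switch else T_cold)

    def next_token_logprobas(s, logits, logit_transformer=None):
        # Assumed None-tolerant application, for call sites that no longer
        # pass a temperature (e.g. the accuracy path).
        if logit_transformer is not None:
            logits = logit_transformer(s, logits)
        return logits.log_softmax(dim=-1)

    logits = torch.randn(4, 16)  # (batch, vocabulary)
    print(next_token_logprobas(3, logits, heater(1.5)).exp().sum(dim=-1))  # ~1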