From b487ed456981e8b52b1bcff1f7e9058270153c26 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 21 Aug 2024 17:08:59 +0200 Subject: [PATCH] Update. --- main.py | 230 ++++++++++++++++++++++++-------------------------------- 1 file changed, 97 insertions(+), 133 deletions(-) diff --git a/main.py b/main.py index d98031e..65695af 100755 --- a/main.py +++ b/main.py @@ -65,7 +65,7 @@ parser.add_argument("--c_quiz_multiplier", type=int, default=1) parser.add_argument("--learning_rate", type=float, default=5e-4) -parser.add_argument("--lambda_H", type=float, default=0.0) +parser.add_argument("--reboot", action="store_true", default=False) parser.add_argument("--schedule_free", action="store_true", default=False) @@ -395,7 +395,9 @@ def one_epoch(model, quiz_machine, local_device=main_device): nb_train_samples, acc_train_loss = 0, 0.0 full_input, full_mask_loss = quiz_machine.data_input( - args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier + args.nb_train_samples, + model.train_c_quiz_bags + common_c_quiz_bags, + args.c_quiz_multiplier, ) src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)) @@ -472,76 +474,6 @@ c_quizzes_procedure = [ ###################################################################### -def save_additional_results(n_epoch, model, models, c_quizzes_procedure): - # Save generated quizzes with the successive generation steps - - recorder = [] - - c_quizzes = quiz_machine.generate_c_quizzes( - 64, - model_for_generation=model, - procedure=c_quizzes_procedure, - recorder=recorder, - ) - - # This is nb_quizzes x nb_models - - l = [ - quiz_machine.models_logprobas( - model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - + quiz_machine.models_logprobas( - model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - + quiz_machine.models_logprobas( - model, c_quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - + quiz_machine.models_logprobas( - model, c_quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - for model in models - ] - - seq_logprobas = torch.cat([x[:, None] for x in l], dim=1) - probas = seq_logprobas.exp() - - comments = [] - - for l in seq_logprobas: - comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) - - ## - - c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1) - predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1) - nb_steps = c_quizzes.size(1) - c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1)) - predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1)) - - # We have comments only for the final quiz, not the successive - # steps, so we have to add nb_steps-1 empty comments - - steps_comments = [] - for c in comments: - steps_comments += [""] * (nb_steps - 1) + [c] - - filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png" - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=c_quizzes, - predicted_parts=predicted_parts, - comments=steps_comments, - nrow=nb_steps * 2, # two quiz per row - ) - - log_string(f"wrote {filename}") - - -###################################################################### - - def model_proba_solutions(model, quizzes): l = ( quiz_machine.models_logprobas( @@ -561,6 +493,9 @@ def model_proba_solutions(model, quizzes): return l.exp() +###################################################################### + + def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models) nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate @@ -804,60 +739,69 @@ class Thinker(nn.Module): ###################################################################### -models = [] - - -def compute_causal_attzero(t_q, t_k): - return t_q < t_k +def create_models(): + models = [] -if args.schedule_free: - import schedulefree + def compute_causal_attzero(t_q, t_k): + return t_q < t_k -for k in range(args.nb_gpts): - log_string(f"creating model {k}") + if args.schedule_free: + import schedulefree + + for k in range(args.nb_gpts): + log_string(f"creating model {k}") + + model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + compute_attzero=compute_causal_attzero, + dropout=args.dropout, + ).to(main_device) + + class UpperBoundStd(nn.Module): + def __init__(self, std_max=1.0): + super().__init__() + self.std_max = std_max + + def forward(self, x): + std = x.std(dim=-1, keepdim=True) + y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max) + return y + + if args.logit_std_max > 0: + model.readout.f = nn.Sequential( + model.readout.f, UpperBoundStd(std_max=args.logit_std_max) + ) - model = mygpt.MyGPT( - vocabulary_size=vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - compute_attzero=compute_causal_attzero, - dropout=args.dropout, - ).to(main_device) + model.id = k + model.train_c_quiz_bags = [] + model.test_c_quiz_bags = [] - class UpperBoundStd(nn.Module): - def __init__(self, std_max=1.0): - super().__init__() - self.std_max = std_max + if args.schedule_free: + model.optimizer = schedulefree.AdamWScheduleFree( + model.parameters(), lr=args.learning_rate + ) + else: + model.optimizer = torch.optim.Adam( + model.parameters(), lr=args.learning_rate + ) - def forward(self, x): - std = x.std(dim=-1, keepdim=True) - y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max) - return y + model.test_accuracy = 0.0 + model.gen_test_accuracy = 0.0 + model.gen_state_dict = copy.deepcopy(model.state_dict()) + models.append(model) - if args.logit_std_max > 0: - model.readout.f = nn.Sequential( - model.readout.f, UpperBoundStd(std_max=args.logit_std_max) - ) + return models - model.id = k - model.train_c_quiz_bags = [] - model.test_c_quiz_bags = [] - if args.schedule_free: - model.optimizer = schedulefree.AdamWScheduleFree( - model.parameters(), lr=args.learning_rate - ) - else: - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) +common_c_quiz_bags = [] - model.test_accuracy = 0.0 - model.best_test_accuracy = 0.0 - model.best_dict = copy.deepcopy(model.state_dict()) - models.append(model) +models = create_models() ###################################################################### @@ -878,8 +822,8 @@ if args.resume: model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.test_accuracy = d["test_accuracy"] - model.best_test_accuracy = d["best_test_accuracy"] - model.best_dict = d["best_dict"] + model.gen_test_accuracy = d["gen_test_accuracy"] + model.gen_state_dict = d["gen_state_dict"] model.train_c_quiz_bags = d["train_c_quiz_bags"] model.test_c_quiz_bags = d["test_c_quiz_bags"] log_string(f"successfully loaded {filename}") @@ -906,10 +850,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### if args.nb_new_c_quizzes_for_train is None: - args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250 + args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50 if args.nb_new_c_quizzes_for_test is None: - args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250 + args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50 log_string( f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}" @@ -1128,41 +1072,64 @@ for n_epoch in range(current_epoch, args.nb_epochs): cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models]) log_string(f"current_test_accuracies {cta}") - cta = " ".join([f"{float(m.best_test_accuracy):.04f}" for m in models]) - log_string(f"current_best_test_accuracies {cta}") + cta = " ".join([f"{float(m.gen_test_accuracy):.04f}" for m in models]) + log_string(f"current_gen_test_accuracies {cta}") ################################################## for model in models: if model.test_accuracy >= args.accuracy_to_make_c_quizzes: log_string( - f"storing_best model {model.id} accuracy {model.best_test_accuracy} -> {model.test_accuracy}" + f"storing_gen model {model.id} accuracy {model.gen_test_accuracy} -> {model.test_accuracy}" ) - model.best_dict = copy.deepcopy(model.state_dict()) - model.best_test_accuracy = model.test_accuracy + model.gen_state_dict = copy.deepcopy(model.state_dict()) + model.gen_test_accuracy = model.test_accuracy # we restart if total_time_generating_c_quizzes == 0: total_time_training_models = 0 if ( - min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes + min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes and total_time_training_models >= total_time_generating_c_quizzes ): + ###################################################################### + # Re-initalize if there are enough culture quizzes + + if args.reboot: + nb_c_quizzes_per_model = [ + sum([x.size(0) for x in model.train_c_quiz_bags]) for model in models + ] + + m = max(nb_c_quizzes_per_model) + + if m >= args.nb_train_samples: + model = models[nb_c_quizzes_per_model.index(m)] + common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0)) + nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags]) + log_string( + f"rebooting the models with {nb_common_c_quizzes} culture quizzes" + ) + + models = create_models() + total_time_generating_c_quizzes = 0 + total_time_training_models = 0 + for model in models: model.current_dict = copy.deepcopy(model.state_dict()) - model.load_state_dict(model.best_dict) + model.load_state_dict(model.gen_state_dict) start_time = time.perf_counter() + record_new_c_quizzes( models, quiz_machine, args.nb_new_c_quizzes_for_train, args.nb_new_c_quizzes_for_test, ) + total_time_generating_c_quizzes += time.perf_counter() - start_time - # Force one epoch of training for model in models: model.load_state_dict(model.current_dict) @@ -1204,9 +1171,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): total_time_training_models += time.perf_counter() - start_time - for model in weakest_models: - save_additional_results(n_epoch, model, models, c_quizzes_procedure) - # Save the models to disk for model in models: @@ -1216,8 +1180,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): "state_dict": model.state_dict(), "optimizer_state_dict": model.optimizer.state_dict(), "test_accuracy": model.test_accuracy, - "best_test_accuracy": model.best_test_accuracy, - "best_dict": model.best_dict, + "gen_test_accuracy": model.gen_test_accuracy, + "gen_state_dict": model.gen_state_dict, "train_c_quiz_bags": model.train_c_quiz_bags, "test_c_quiz_bags": model.test_c_quiz_bags, }, -- 2.39.5