From 818ad3f2215e5b811c07696fb4cd1cb7012e8e53 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 8 Aug 2024 08:21:23 +0200 Subject: [PATCH] Update. --- main.py | 72 +++++++++++++++++++------------------------------ quiz_machine.py | 39 ++++++++++++--------------- 2 files changed, 45 insertions(+), 66 deletions(-) diff --git a/main.py b/main.py index 8f3568f..86eafea 100755 --- a/main.py +++ b/main.py @@ -395,8 +395,6 @@ def run_tests(model, quiz_machine, local_device=main_device): def one_epoch(model, quiz_machine, local_device=main_device): model.to(local_device).train() - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - nb_train_samples, acc_train_loss = 0, 0.0 hard_w_quizzes = [] @@ -413,7 +411,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): input = input.to(local_device) if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() + model.optimizer.zero_grad() targets = input @@ -435,7 +433,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): loss.backward() if nb_train_samples % args.batch_size == 0: - optimizer.step() + model.optimizer.step() train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) @@ -470,6 +468,7 @@ c_quizzes_procedure = [ (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot), (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold), (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold), # (("f_B", "f_A", "A", "B"), (0, 0, 1, 1), model_transformer_cold), # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), ] @@ -489,22 +488,15 @@ def save_additional_results(model, models, science_w_quizzes): recorder=recorder, ) - ## - - probas = 0 + # This is nb_quizzes x nb_models - for a in range(args.nb_averaging_rounds): - # This is nb_quizzes x nb_models - - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - - probas += seq_logproba.exp() + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) - probas /= args.nb_averaging_rounds + probas = seq_logproba.exp() comments = [] @@ -597,8 +589,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64) - to_recycle = None - while nb_validated_per_model.sum() < nb_to_validate: # We use the model that has generated the fewest quizzes to # balance the number of quizzes per model overall @@ -616,35 +606,24 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_to_generate_per_iteration, model_for_generation=model, procedure=c_quizzes_procedure, - to_recycle=to_recycle, ) # We discard the trivial ones, according to a criterion # specific to the world quizzes (e.g. B=f(B)) - rejected = [] - to_keep = quiz_machine.problem.trivial(c_quizzes) == False - if not to_keep.all(): - rejected.append(c_quizzes[to_keep == False]) - c_quizzes = c_quizzes[to_keep] - probas = 0 - - for a in range(args.nb_averaging_rounds): - # This is nb_quizzes x nb_models - - seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + # This is nb_quizzes x nb_models - probas += seq_logproba.exp() + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) - probas /= args.nb_averaging_rounds + probas = seq_logproba.exp() nb_succeed = (probas >= args.proba_understands).long().sum(dim=1) nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1) @@ -655,7 +634,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 & (nb_fail <= args.max_fail_to_validate) ) - to_recycle = c_quizzes[to_keep == False] c_quizzes = c_quizzes[to_keep] if c_quizzes.size(0) > 0: @@ -1010,7 +988,6 @@ def train_complexifier(model_gen, model_pred1, model_pred2): ###################################################################### - models = [] for k in range(args.nb_gpts): @@ -1027,9 +1004,11 @@ for k in range(args.nb_gpts): dropout=args.dropout, ).to(main_device) - model.main_test_accuracy = 0.0 model.id = k + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + model.main_test_accuracy = 0.0 + model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes( args.nb_train_samples ) @@ -1048,8 +1027,9 @@ if args.resume: try: d = torch.load(os.path.join(args.result_dir, filename)) - model.load_state_dict(d[0]) - model.main_test_accuracy = d[1] + model.load_state_dict(d["state_dict"]) + model.optimizer.load_state_dict(d["optimizer_state_dict"]) + model.main_test_accuracy = d["main_test_accuracy"] log_string(f"successfully loaded {filename}") except FileNotFoundError: log_string(f"cannot find {filename}") @@ -1305,7 +1285,11 @@ for n_epoch in range(current_epoch, args.nb_epochs): for model in weakest_models: filename = f"gpt_{model.id:03d}.pth" torch.save( - (model.state_dict(), model.main_test_accuracy), + { + "state_dict": model.state_dict(), + "optimizer_state_dict": model.optimizer.state_dict(), + "main_test_accuracy": model.main_test_accuracy, + }, os.path.join(args.result_dir, filename), ) log_string(f"wrote {filename}") diff --git a/quiz_machine.py b/quiz_machine.py index 3fc1066..daa9bbf 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -81,6 +81,7 @@ class QuizMachine: self.answer_len = None self.prompt_noise = prompt_noise + # struct, mask_generate, mask_noise self.understood_structures = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)), (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)), @@ -178,15 +179,15 @@ class QuizMachine: quizzes, from_w = quizzes[i], from_w[i] self.randomize_configuations_inplace( - quizzes, structs=[s for s, m, _ in self.understood_structures] + quizzes, structs=[s for s, _, _ in self.understood_structures] ) if self.prompt_noise > 0.0: - for struct, mask, noise_mask in self.understood_structures: + for struct, _, mask_noise in self.understood_structures: i = self.problem.indices_select(quizzes=quizzes, struct=struct) if i.any(): quizzes[i] = self.problem.inject_noise( - quizzes[i], self.prompt_noise, struct=struct, mask=noise_mask + quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise ) return quizzes, from_w @@ -228,13 +229,15 @@ class QuizMachine: nb = 0 # We consider all the configurations that we train for - for struct, mask, _ in self.understood_structures: + for struct, mask_generate, _ in self.understood_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i] = self.predict( - model=model, quizzes=input[i], struct=struct, mask=mask + model=model, quizzes=input[i], struct=struct, mask=mask_generate ) - predicted_parts[i] = torch.tensor(mask, device=self.device)[None, :] + predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[ + None, : + ] solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1 correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long() @@ -329,8 +332,8 @@ class QuizMachine: models_for_validation, c_quizzes, struct, - mask, - noise_mask=None, + mask_value, + mask_noise=None, device=None, ): if device is None: @@ -344,10 +347,10 @@ class QuizMachine: device=device, ) - if self.prompt_noise > 0.0 and noise_mask is not None: - c_quizzes = self.problem.inject_noise( - c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask - ) + # if self.prompt_noise > 0.0 and mask_noise is not None: + # c_quizzes = self.problem.inject_noise( + # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise + # ) for model in models_for_validation: with torch.autograd.no_grad(): @@ -359,7 +362,7 @@ class QuizMachine: seq_logproba.split(self.batch_size), ): input = input.to(device) - ar_mask = self.make_ar_mask(input, struct=struct, mask=mask) + ar_mask = self.make_ar_mask(input, struct=struct, mask=mask_value) output = model(mygpt.BracketedSequence(input)).x l[:, model.id] = ( -F.cross_entropy( @@ -374,9 +377,7 @@ class QuizMachine: ###################################################################### - def generate_c_quizzes( - self, nb, model_for_generation, procedure, to_recycle=None, recorder=None - ): + def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None): seq_logproba = torch.zeros(nb, device=self.device) c_quizzes = None @@ -408,12 +409,6 @@ class QuizMachine: self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B")) ) - if to_recycle is not None and to_recycle.size(0) > 0: - to_recycle = self.problem.reconfigure(to_recycle, s) - c_quizzes[: to_recycle.size(0)] = to_recycle - - to_recycle = None - c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) return c_quizzes.to("cpu") -- 2.39.5