From 0e55fc1591ca4f5675d1b9ce543f768c41b9a384 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 12 Aug 2024 07:08:22 +0200
Subject: [PATCH] Update.

---
 main.py         | 39 ++++++++++++-------------
 quiz_machine.py | 77 +++++++------------------------------------------
 2 files changed, 29 insertions(+), 87 deletions(-)

diff --git a/main.py b/main.py
index cd6e3a9..f51ab38 100755
--- a/main.py
+++ b/main.py
@@ -390,7 +390,9 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0

-        full_input, full_mask_loss = quiz_machine.data_input(model, split="test")
+        full_input, full_mask_loss = quiz_machine.data_input(
+            model, args.nb_test_samples
+        )
         src = zip(
             full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
         )
@@ -439,7 +441,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):

     hard_w_quizzes = []

-    full_input, full_mask_loss = quiz_machine.data_input(model, split="train")
+    full_input, full_mask_loss = quiz_machine.data_input(model, args.nb_train_samples)
     src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))

     for input, mask_loss in tqdm.tqdm(
@@ -626,13 +628,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)

     while nb_validated_per_model.sum() < nb_to_validate:
-        # We use the model that has generated the fewest quizzes to
-        # balance the number of quizzes per model overall
-
-        # model_for_generation = sorted(
-        #     models, key=lambda m: nb_validated_per_model[m.id]
-        # )[0]
-
         model_for_generation = models[torch.randint(len(models), (1,)).item()]

         # We generate quizzes with a procedure that injects some
@@ -653,6 +648,18 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10

         # This is nb_quizzes x nb_models

+        solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
+
+        for m in models:
+            solved_c_quizzes[:, m.id] = quiz_machine.predict(
+                m,
+                solved_c_quizzes[:, m.id],
+                struct=("A", "f_A", "B", "f_B"),
+                mask=(0, 0, 0, 1),
+            )
+
+        # FINISH
+
         seq_logproba = quiz_machine.models_logprobas(
             models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
         ) + quiz_machine.models_logprobas(
@@ -1043,6 +1050,7 @@ for k in range(args.nb_gpts):
     ).to(main_device)

     model.id = k
+    model.c_quiz_bags = []

     if args.schedule_free:
         model.optimizer = schedulefree.AdamWScheduleFree(
@@ -1053,12 +1061,6 @@ for k in range(args.nb_gpts):

     model.main_test_accuracy = 0.0

-    model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
-        args.nb_train_samples
-    )
-
-    model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
-
     models.append(model)

 ######################################################################
@@ -1312,7 +1314,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         record_new_c_quizzes(
             models,
             quiz_machine,
-            nb_for_train=args.nb_new_c_quizzes_for_train,
+            nb_for_train=args.nb_new_c_quizzes_for_train,
             nb_for_test=args.nb_new_c_quizzes_for_test,
         )

@@ -1366,11 +1368,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):

     ######################################################################

-    # Renew the training samples
-
-    for model in weakest_models:
-        quiz_machine.renew_train_w_quizzes(model=model)
-
     if args.log_command is not None:
         s = args.log_command.split()
         s.insert(1, args.result_dir)
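Note on the new per-model solving step added to record_new_c_quizzes above: c_quizzes[:, None, :].expand(-1, len(models), -1).clone() replicates every candidate quiz once per model, so each model can write its own completion without the copies sharing storage. The following stand-alone sketch is illustrative only (the sizes and the fill-in step are made up; in the patch the fill-in is quiz_machine.predict):

    import torch

    nb_quizzes, nb_models, seq_len = 4, 3, 12
    mask_len = seq_len // 4  # pretend the last quarter is the masked f_B part

    c_quizzes = torch.randint(10, (nb_quizzes, seq_len))

    # expand() yields a broadcast view whose copies share storage; clone()
    # materializes them so per-model writes do not leak into other columns
    solved_c_quizzes = c_quizzes[:, None, :].expand(-1, nb_models, -1).clone()

    for m_id in range(nb_models):
        # stand-in for quiz_machine.predict(...): each model overwrites the
        # masked part with its own completion
        solved_c_quizzes[:, m_id, -mask_len:] = m_id

    # the copies still agree with the original everywhere except the masked part
    assert (solved_c_quizzes[:, 0, :-mask_len] == c_quizzes[:, :-mask_len]).all()
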
diff --git a/quiz_machine.py b/quiz_machine.py
index 92da03d..1d89cf4 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -97,10 +97,6 @@ class QuizMachine:

         self.test_structures = self.train_structures

-        self.LOCK_C_QUIZZES = threading.Lock()
-        self.train_c_quizzes = []
-        self.test_c_quizzes = []
-
     def vocabulary_size(self):
         return self.problem.nb_token_values

@@ -150,40 +146,21 @@ class QuizMachine:

     ######################################################################

-    def data_input(self, model, split="train"):
-        assert split in {"train", "test"}
-
-        with self.LOCK_C_QUIZZES:
-            if split == "train":
-                w_quizzes = model.train_w_quizzes
-                c_quizzes = self.train_c_quizzes
-            else:
-                w_quizzes = model.test_w_quizzes
-                c_quizzes = self.test_c_quizzes
-
-            if len(c_quizzes) > 0:
-                c_quizzes = torch.cat(c_quizzes, dim=0)
+    def data_input(self, model, nb_samples):
+        if len(model.c_quiz_bags) > 0:
+            c_quizzes = torch.cat(model.c_quiz_bags, dim=0)

-                if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                    i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                    c_quizzes = c_quizzes[i]
+            if c_quizzes.size(0) > nb_samples // 2:
+                i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+                c_quizzes = c_quizzes[i]

-                i = torch.randperm(w_quizzes.size(0))[
-                    : w_quizzes.size(0) - c_quizzes.size(0)
-                ]
-                w_quizzes = w_quizzes[i]
-
-                quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
-                from_w = torch.arange(
-                    quizzes.size(0), device=quizzes.device
-                ) < w_quizzes.size(0)
-
-            else:
-                quizzes = w_quizzes.clone()
-                from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
+            w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+            quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+        else:
+            quizzes = self.problem.generate_w_quizzes(nb_samples)

         i = torch.randperm(quizzes.size(0), device=quizzes.device)
-        quizzes, from_w = quizzes[i], from_w[i]
+        quizzes = quizzes[i]

         self.randomize_configuations_inplace(
             quizzes, structs=[s for s, _, _, _ in self.train_structures]
@@ -292,38 +269,6 @@ class QuizMachine:

     ######################################################################

-    def renew_train_w_quizzes(self, model):
-        if hasattr(model, "hard_w_quizzes"):
-            hard_w_quizzes = self.problem.reconfigure(
-                model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B")
-            )
-            self.logger(
-                f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
-            )
-            if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
-                nb_to_generate = 0
-                model.train_w_quizzes[...] = hard_w_quizzes[
-                    torch.randperm(hard_w_quizzes.size(0))[
-                        : model.train_w_quizzes.size(0)
-                    ]
-                ]
-            else:
-                nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0)
-                model.train_w_quizzes[...] = torch.cat(
-                    [
-                        hard_w_quizzes,
-                        self.problem.generate_w_quizzes(nb_to_generate),
-                    ],
-                    dim=0,
-                )
-        else:
-            nb_to_generate = 0
-            model.train_w_quizzes[...] = self.problem.generate_w_quizzes(
-                model.train_w_quizzes.size(0)
-            )
-
-    ######################################################################
-
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         with self.LOCK_C_QUIZZES:
             if for_train:
-- 
2.39.5
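For context, the rewritten data_input no longer reads per-split stored tensors: it generates world quizzes on demand and mixes in the per-model c_quiz_bags, capping c_quizzes at half of the requested samples. A self-contained sketch of that mixing logic follows; generate_w_quizzes below is a stand-in for self.problem.generate_w_quizzes, not the actual Problem API, and the sizes are made up:

    import torch

    def mixed_input(c_quiz_bags, nb_samples, generate_w_quizzes):
        # c_quiz_bags: list of (n_i, seq_len) tensors, as kept in model.c_quiz_bags
        if len(c_quiz_bags) > 0:
            c_quizzes = torch.cat(c_quiz_bags, dim=0)
            # keep at most half of the batch as c_quizzes, picked at random
            if c_quizzes.size(0) > nb_samples // 2:
                i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
                c_quizzes = c_quizzes[i]
            w_quizzes = generate_w_quizzes(nb_samples - c_quizzes.size(0))
            quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
        else:
            quizzes = generate_w_quizzes(nb_samples)
        # shuffle so w_quizzes and c_quizzes are interleaved across batches
        return quizzes[torch.randperm(quizzes.size(0), device=quizzes.device)]

    # toy usage with a dummy generator of (n, 8) quizzes
    generate = lambda n: torch.randint(10, (n, 8))
    bags = [torch.randint(10, (5, 8)), torch.randint(10, (7, 8))]
    batch = mixed_input(bags, 16, generate)  # at most 8 of the 16 rows are c_quizzes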