From 17cda95f9b478e0919ab4c6122a66a3a58cd904f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Aug 2024 07:53:54 +0200 Subject: [PATCH] Update. --- main.py | 5 +++-- quiz_machine.py | 42 ++++++++++++++++++++++-------------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/main.py b/main.py index 00a6cd1..cd78959 100755 --- a/main.py +++ b/main.py @@ -969,6 +969,7 @@ def test_ae(local_device=main_device): targets = input input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA + result = (1 - mask_generate) * input + mask_generate * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) @@ -985,8 +986,8 @@ def test_ae(local_device=main_device): result[not_converged] = update[not_converged] not_converged = (pred_result != result).max(dim=1).values nb_it += 1 - print("DEBUG", nb_it, i.long().sum().item()) - if not i.any() or nb_it > 100: + print("DEBUG", nb_it, not_converged.long().sum().item()) + if not not_converged.any() or nb_it > 100: break correct = (result == targets).min(dim=1).values.long() diff --git a/quiz_machine.py b/quiz_machine.py index 0f13964..af24c92 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -81,7 +81,7 @@ class QuizMachine: self.answer_len = None self.prompt_noise = prompt_noise - # struct, quad_generate, quad_noise, quad_loss + # quad_order, quad_generate, quad_noise, quad_loss self.train_structures = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), @@ -171,32 +171,34 @@ class QuizMachine: else: quizzes = self.problem.generate_w_quizzes(nb_samples) + # shuffle + i = torch.randperm(quizzes.size(0), device=quizzes.device) quizzes = quizzes[i] - self.randomize_configuations_inplace( - quizzes, quad_orders=[s for s, _, _, _ in data_structures] - ) + # Re-order and inject noise quiz_mask_generate = quizzes.new_full(quizzes.size(), 1) quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) - - for quad_order, quad_generate, quad_noise, quad_loss in data_structures: - i = self.problem.indices_select(quizzes=quizzes, quad_order=quad_order) - if i.any(): - if self.prompt_noise > 0.0: - quizzes[i] = self.problem.inject_noise( - quizzes[i], - self.prompt_noise, - quad_order=quad_order, - quad_noise=quad_noise, - ) - quiz_mask_generate[i] = self.make_quiz_mask( - quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate - ) - quiz_mask_loss[i] = self.make_quiz_mask( - quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss + order_ids = torch.randint(len(data_structures), (quizzes.size(0),)) + + for j, s in enumerate(data_structures): + quad_order, quad_generate, quad_noise, quad_loss = s + i = order_ids == j + quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order) + if self.prompt_noise > 0.0: + quizzes[i] = self.problem.inject_noise( + quizzes[i], + self.prompt_noise, + quad_order=quad_order, + quad_noise=quad_noise, ) + quiz_mask_generate[i] = self.make_quiz_mask( + quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate + ) + quiz_mask_loss[i] = self.make_quiz_mask( + quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss + ) return quizzes, quiz_mask_generate, quiz_mask_loss -- 2.39.5