From 6bdaaf9fb700637f595219cd27ca5e4a5c06af10 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 1 Aug 2024 09:12:49 +0200 Subject: [PATCH] Update. --- grids.py | 1 + quiz_machine.py | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/grids.py b/grids.py index f12fcb9..410b538 100755 --- a/grids.py +++ b/grids.py @@ -172,6 +172,7 @@ class Grids(problem.Problem): S = self.height * self.width mask = torch.tensor(mask, device=quizzes.device) mask = mask[None, :, None].expand(1, 4, S + 1) + mask[:, :, 0] = 0 mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long() mask = mask.reshape(1, -1).expand_as(quizzes) random = torch.randint(self.nb_colors, mask.size()) diff --git a/quiz_machine.py b/quiz_machine.py index 1e973c5..fb451f2 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -179,7 +179,10 @@ class QuizMachine: if self.prompt_noise > 0.0: quizzes = self.problem.inject_noise( - quizzes, self.prompt_noise, ("A", "f_A", "B", "f_B"), (1, 0, 1, 0) + quizzes, + self.prompt_noise, + struct=("A", "f_A", "B", "f_B"), + mask=(1, 0, 1, 0), ) self.randomize_configuations_inplace( @@ -304,10 +307,6 @@ class QuizMachine: model.train_w_quizzes.size(0) ) - self.randomize_configuations_inplace( - model.train_w_quizzes, structs=[s for s, m in self.understood_structures] - ) - ###################################################################### def store_c_quizzes(self, new_c_quizzes, for_train=True): -- 2.39.5