From: François Fleuret Date: Thu, 1 Aug 2024 07:12:49 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=6bdaaf9fb700637f595219cd27ca5e4a5c06af10;p=culture.git Update. --- diff --git a/grids.py b/grids.py index f12fcb9..410b538 100755 --- a/grids.py +++ b/grids.py @@ -172,6 +172,7 @@ class Grids(problem.Problem): S = self.height * self.width mask = torch.tensor(mask, device=quizzes.device) mask = mask[None, :, None].expand(1, 4, S + 1) + mask[:, :, 0] = 0 mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long() mask = mask.reshape(1, -1).expand_as(quizzes) random = torch.randint(self.nb_colors, mask.size()) diff --git a/quiz_machine.py b/quiz_machine.py index 1e973c5..fb451f2 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -179,7 +179,10 @@ class QuizMachine: if self.prompt_noise > 0.0: quizzes = self.problem.inject_noise( - quizzes, self.prompt_noise, ("A", "f_A", "B", "f_B"), (1, 0, 1, 0) + quizzes, + self.prompt_noise, + struct=("A", "f_A", "B", "f_B"), + mask=(1, 0, 1, 0), ) self.randomize_configuations_inplace( @@ -304,10 +307,6 @@ class QuizMachine: model.train_w_quizzes.size(0) ) - self.randomize_configuations_inplace( - model.train_w_quizzes, structs=[s for s, m in self.understood_structures] - ) - ###################################################################### def store_c_quizzes(self, new_c_quizzes, for_train=True):