From cb2ca18632adce27777f6aa5e076b5f6318aab6e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 1 Aug 2024 10:09:33 +0200 Subject: [PATCH] Update. --- grids.py | 21 ++++++++++++++------- quiz_machine.py | 4 ++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/grids.py b/grids.py index 410b538..19b9ce4 100755 --- a/grids.py +++ b/grids.py @@ -167,16 +167,23 @@ class Grids(problem.Problem): self.check_structure(quizzes, struct) return struct - def inject_noise(self, quizzes, noise, struct, mask): - assert self.check_structure(quizzes, struct=struct) + def inject_noise(self, quizzes, noise, struct=None, mask=None): + # assert self.check_structure(quizzes, struct=struct) S = self.height * self.width - mask = torch.tensor(mask, device=quizzes.device) - mask = mask[None, :, None].expand(1, 4, S + 1) - mask[:, :, 0] = 0 + + # mask = torch.tensor(mask, device=quizzes.device) + # mask = mask[None, :, None].expand(1, 4, S + 1) + # mask[:, :, 0] = 0 + # mask = mask.reshape(1, -1).expand_as(quizzes) + + mask = quizzes.new_full(quizzes.size(), 1) + mask[:, 0 * (S + 1)] = 0 + mask[:, 1 * (S + 1)] = 0 + mask[:, 2 * (S + 1)] = 0 + mask[:, 3 * (S + 1)] = 0 mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long() - mask = mask.reshape(1, -1).expand_as(quizzes) - random = torch.randint(self.nb_colors, mask.size()) + random = torch.randint(self.nb_colors, mask.size()) quizzes = mask * random + (1 - mask) * quizzes return quizzes diff --git a/quiz_machine.py b/quiz_machine.py index fb451f2..90ca5e6 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -181,8 +181,8 @@ class QuizMachine: quizzes = self.problem.inject_noise( quizzes, self.prompt_noise, - struct=("A", "f_A", "B", "f_B"), - mask=(1, 0, 1, 0), + # struct=("A", "f_A", "B", "f_B"), + # mask=(1, 0, 1, 0), ) self.randomize_configuations_inplace( -- 2.39.5