From 7ba095bd41c4f6fc48cf49046af2d5ab4e3b1e39 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 1 Aug 2024 10:18:14 +0200 Subject: [PATCH] Update. --- grids.py | 19 +++++++------------ quiz_machine.py | 18 ++++++++++-------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/grids.py b/grids.py index 9b1ed50..f747c47 100755 --- a/grids.py +++ b/grids.py @@ -167,23 +167,18 @@ class Grids(problem.Problem): self.check_structure(quizzes, struct) return struct - def inject_noise(self, quizzes, noise): # , struct=None, mask=None): - # assert self.check_structure(quizzes, struct=struct) + def inject_noise(self, quizzes, noise, struct, mask): + assert self.check_structure(quizzes, struct=struct) S = self.height * self.width - # mask = torch.tensor(mask, device=quizzes.device) - # mask = mask[None, :, None].expand(1, 4, S + 1) - # mask[:, :, 0] = 0 - # mask = mask.reshape(1, -1).expand_as(quizzes) + mask = torch.tensor(mask, device=quizzes.device) + mask = mask[None, :, None].expand(1, 4, S + 1) + mask[:, :, 0] = 0 + mask = mask.reshape(1, -1).expand_as(quizzes) - mask = quizzes.new_full(quizzes.size(), 1) - mask[:, 0 * (S + 1)] = 0 - mask[:, 1 * (S + 1)] = 0 - mask[:, 2 * (S + 1)] = 0 - mask[:, 3 * (S + 1)] = 0 mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long() - random = torch.randint(self.nb_colors, mask.size()) + quizzes = mask * random + (1 - mask) * quizzes return quizzes diff --git a/quiz_machine.py b/quiz_machine.py index 90ca5e6..a4ca60b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -177,18 +177,20 @@ class QuizMachine: i = torch.randperm(quizzes.size(0), device=quizzes.device) quizzes, from_w = quizzes[i], from_w[i] - if self.prompt_noise > 0.0: - quizzes = self.problem.inject_noise( - quizzes, - self.prompt_noise, - # struct=("A", "f_A", "B", "f_B"), - # mask=(1, 0, 1, 0), - ) - self.randomize_configuations_inplace( quizzes, structs=[s for s, m in self.understood_structures] ) + if self.prompt_noise > 0.0: + for struct, mask in self.understood_structures: + i = self.problem.indices_select(quizzes=input, struct=struct) + input[i] = self.problem.inject_noise( + input[i], + self.prompt_noise, + struct=struct, + mask=(1 - k for k in mask), + ) + return quizzes, from_w ###################################################################### -- 2.20.1