From 1be4097518312d6a513c583b541ed4b5c425eb5f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 24 Jul 2024 20:37:01 +0200 Subject: [PATCH] Update. --- grids.py | 29 +++++++++++++++++++---------- quiz_machine.py | 46 +++++++++++++++++++++++++++++----------------- 2 files changed, 48 insertions(+), 27 deletions(-) diff --git a/grids.py b/grids.py index 5ddcf32..80b8b1d 100755 --- a/grids.py +++ b/grids.py @@ -139,18 +139,24 @@ class Grids(problem.Problem): def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): S = self.height * self.width + result = quizzes.new(quizzes.size()) + + struct_from = self.get_structure(quizzes[:1]) + i = self.indices_select(quizzes, struct_from) - struct_from = self.get_structure(quizzes) sf = dict((l, n) for n, l in enumerate(struct_from)) - result = quizzes.new(quizzes.size()) - q = quizzes.reshape(quizzes.size(0), 4, S + 1) - r = result.reshape(result.size(0), 4, S + 1) + q = quizzes.reshape(-1, 4, S + 1)[i] + + result[i, 0 * (S + 1) : 1 * (S + 1)] = q[:, sf[struct[0]], :] + result[i, 1 * (S + 1) : 2 * (S + 1)] = q[:, sf[struct[1]], :] + result[i, 2 * (S + 1) : 3 * (S + 1)] = q[:, sf[struct[2]], :] + result[i, 3 * (S + 1) : 4 * (S + 1)] = q[:, sf[struct[3]], :] + + j = i == False - r[:, 0] = q[:, sf[struct[0]], :] - r[:, 1] = q[:, sf[struct[1]], :] - r[:, 2] = q[:, sf[struct[2]], :] - r[:, 3] = q[:, sf[struct[3]], :] + if j.any(): + result[j] = self.reconfigure(quizzes[j], struct=struct) return result @@ -258,8 +264,8 @@ class Grids(problem.Problem): y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - y[:, :, :, torch.arange(0, y.size(3), scale)] = 0 - y[:, :, torch.arange(0, y.size(2), scale), :] = 0 + y[:, :, :, torch.arange(0, y.size(3), scale)] = 224 + y[:, :, torch.arange(0, y.size(2), scale), :] = 224 for n in range(m.size(0)): for i in range(m.size(1)): @@ -1446,9 +1452,12 @@ if __name__ == "__main__": nb = 5 quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) + print(quizzes) print(grids.get_structure(quizzes)) quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) + print("DEBUG2", quizzes) print(grids.get_structure(quizzes)) + print(quizzes) i = torch.rand(quizzes.size(0)) < 0.5 diff --git a/quiz_machine.py b/quiz_machine.py index 2fb196c..a384377 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -131,6 +131,12 @@ class QuizMachine: self.prompt_len = None self.answer_len = None + self.configurations = [ + ("A", "f_A", "B", "f_B"), # The standard order + ("f_A", "A", "f_B", "B"), # The reverse order for validation + ("f_B", "f_A", "A", "B"), # The synthesis order + ] + self.LOCK_C_QUIZZES = threading.Lock() self.train_c_quizzes = [] self.test_c_quizzes = [] @@ -212,7 +218,8 @@ class QuizMachine: nb = 0 for struct, mask in [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1)), - (("f_B", "f_A", "B", "A"), (0, 1, 1, 1)), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)), ]: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() @@ -220,6 +227,7 @@ class QuizMachine: model=model, quizzes=input[i], struct=struct, mask=mask ) + print(f"{nb=} {input.size(0)=}") assert nb == input.size(0) main_test_accuracy = correct.sum() / correct.size(0) @@ -236,27 +244,29 @@ class QuizMachine: ###################################################################### - def flip_half_in_place(self, quizzes): - r = torch.rand(quizzes.size(0), device=quizzes.device) < 0.5 - i = self.problem.indices_select( - quizzes=quizzes, struct=("A", "f_A", "B", "f_B") - ) - quizzes[i & r] = self.problem.reconfigure( - quizzes[i & r], struct=("f_B", "f_A", "B", "A") - ) - j = self.problem.indices_select( - quizzes=quizzes, struct=("f_B", "f_A", "B", "A") - ) - quizzes[j & r] = self.problem.reconfigure( - quizzes[j & r], struct=("A", "f_A", "B", "f_B") + def randomize_configuations_inplace(self, quizzes, configurations): + r = torch.randint( + len(configurations), (quizzes.size(0),), device=quizzes.device ) + for c in range(len(configurations)): + quizzes[r == c] = self.problem.reconfigure( + quizzes[r == c], struct=configurations[c] + ) + def create_w_quizzes(self, model, nb_train_samples, nb_test_samples): model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples) model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples) - self.flip_half_in_place(model.train_w_quizzes) - self.flip_half_in_place(model.test_w_quizzes) + self.randomize_configuations_inplace( + model.train_w_quizzes, configurations=self.configurations + ) + + self.randomize_configuations_inplace( + model.test_w_quizzes, configurations=self.configurations + ) + + # print(model.train_w_quizzes.sum()) ###################################################################### @@ -287,7 +297,9 @@ class QuizMachine: model.train_w_quizzes.size(0) ) - self.flip_half_in_place(model.train_w_quizzes) + self.randomize_configuations_inplace( + model.train_w_quizzes, configurations=self.configurations + ) ###################################################################### -- 2.39.5