From 813c9a661e188c18fa3598b876f68b269df94569 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 19 Sep 2024 20:15:48 +0200 Subject: [PATCH] Update. --- grids.py | 193 ++++++++++++++----------------------------------------- main.py | 15 +---- 2 files changed, 52 insertions(+), 156 deletions(-) diff --git a/grids.py b/grids.py index 5e623cb..fb31c7d 100755 --- a/grids.py +++ b/grids.py @@ -134,17 +134,20 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations): class Grids(problem.Problem): - grid_gray = 64 - thickness = 1 - background_gray = 255 + # grid_gray = 64 + # thickness = 1 + # background_gray = 255 + # dots = False # grid_gray=240 # thickness=1 # background_gray=240 + # dots = False - # grid_gray = 255 - # thickness = 0 - # background_gray = 240 + grid_gray = 200 + thickness = 0 + background_gray = 240 + dots = True named_colors = [ ("white", [background_gray, background_gray, background_gray]), @@ -161,76 +164,10 @@ class Grids(problem.Problem): ("gray", [128, 128, 128]), ] - def check_order(self, quizzes, quad_order): - S = self.height * self.width - - return ( - (quizzes[:, 0 * (S + 1)] == self.l2tok[quad_order[0]]) - & (quizzes[:, 1 * (S + 1)] == self.l2tok[quad_order[1]]) - & (quizzes[:, 2 * (S + 1)] == self.l2tok[quad_order[2]]) - & (quizzes[:, 3 * (S + 1)] == self.l2tok[quad_order[3]]) - ).all() - - def get_order(self, quizzes): - S = self.height * self.width - quad_order = tuple( - self.tok2l[n.item()] - for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0] - ) - self.check_order(quizzes, quad_order) - return quad_order - - def inject_noise(self, quizzes, noise, quad_order, quad_noise): - assert self.check_order(quizzes, quad_order=quad_order) - S = self.height * self.width - - mask = torch.tensor(quad_noise, device=quizzes.device) - mask = mask[None, :, None].expand(1, 4, S + 1).clone() - mask[:, :, 0] = 0 - mask = mask.reshape(1, -1).expand_as(quizzes) - mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long() - random = torch.randint(self.nb_colors, mask.size()) - quizzes = mask * random + (1 - mask) * quizzes - - return quizzes - def pure_noise(self, nb, device): result = torch.randint( - self.nb_colors, (nb, 4 * (self.height * self.height + 1)), device=device + self.nb_colors, (nb, 4 * (self.height * self.height)), device=device ) - result.view(nb, 4, -1)[:, 0, 0] = self.token_A - result.view(nb, 4, -1)[:, 1, 0] = self.token_f_A - result.view(nb, 4, -1)[:, 2, 0] = self.token_B - result.view(nb, 4, -1)[:, 3, 0] = self.token_f_B - return result - - # What a mess - def reconfigure(self, quizzes, quad_order=("A", "f_A", "B", "f_B")): - if torch.is_tensor(quizzes): - return self.reconfigure([quizzes], quad_order=quad_order)[0] - - S = self.height * self.width - result = [x.new(x.size()) for x in quizzes] - - quad_order_from = self.get_order(quizzes[0][:1]) - i = self.indices_select(quizzes[0], quad_order_from) - - sf = dict((l, n) for n, l in enumerate(quad_order_from)) - - for q in range(4): - k = sf[quad_order[q]] - for x, y in zip(quizzes, result): - l = x.size(1) // 4 - y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l] - - j = i == False - - if j.any(): - for z, y in zip( - self.reconfigure([x[j] for x in quizzes], quad_order=quad_order), result - ): - y[j] = z - return result def trivial(self, quizzes): @@ -241,22 +178,6 @@ class Grids(problem.Problem): dim=1 ).values - def make_quiz_mask( - self, quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=(0, 0, 0, 1) - ): - assert self.check_order(quizzes, quad_order) - - ar_mask = quizzes.new_zeros(quizzes.size()) - - S = self.height * self.width - a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:] - a[:, 0, :] = quad_mask[0] - a[:, 1, :] = quad_mask[1] - a[:, 2, :] = quad_mask[2] - a[:, 3, :] = quad_mask[3] - - return ar_mask - def text2quiz(self, t): chr2col = [ (".", "white"), @@ -322,32 +243,13 @@ class Grids(problem.Problem): self.colors = torch.tensor([c for _, c in self.named_colors]) self.nb_colors = len(self.colors) - self.token_A = self.nb_colors - self.token_f_A = self.token_A + 1 - self.token_B = self.token_f_A + 1 - self.token_f_B = self.token_B + 1 self.nb_rec_max = 5 self.rfree = torch.tensor([]) - self.l2tok = { - "A": self.token_A, - "f_A": self.token_f_A, - "B": self.token_B, - "f_B": self.token_f_B, - } - - self.tok2l = { - self.token_A: "A", - self.token_f_A: "f_A", - self.token_B: "B", - self.token_f_B: "f_B", - } - self.height = 10 self.width = 10 - self.seq_len = 4 * (1 + self.height * self.width) - self.nb_token_values = self.token_f_B + 1 + self.seq_len = 4 * self.height * self.width self.cache_rec_coo = {} @@ -385,7 +287,7 @@ class Grids(problem.Problem): ###################################################################### def vocabulary_size(self): - return self.nb_token_values + return self.nb_colors def grid2img(self, x, scale=15, grids=True): m = torch.logical_and(x >= 0, x < self.nb_colors).long() @@ -398,14 +300,41 @@ class Grids(problem.Problem): for t in range(self.thickness): y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray + if self.dots: + z = y.reshape( + y.size(0), + y.size(1), + y.size(2) // scale, + scale, + y.size(3) // scale, + scale, + ) + z = z[ + :, + :, + :, + scale // 2 - 2 : scale // 2 + 1, + :, + scale // 2 - 2 : scale // 2 + 1, + ] + z[...] = (z == self.background_gray) * self.grid_gray + ( + z != self.background_gray + ) * z for n in range(m.size(0)): for i in range(m.size(1)): for j in range(m.size(2)): - if m[n, i, j] == 0: - for k in range(3, scale - 2): - y[n, :, i * scale + k, j * scale + k] = 0 - y[n, :, i * scale + k, j * scale + scale - k] = 0 + if x[n, i, j] >= self.nb_colors: + # for k in range(3, scale - 2): + c = self.colors[x[n, i, j] - self.nb_colors][:, None, None] + # y[n, :, i * scale + k, j * scale + k] = c + # y[n, :, i * scale + k, j * scale + scale - k] = c + y[ + n, + :, + i * scale + 3 : i * scale + scale - 2, + j * scale + 3 : j * scale + scale - 2, + ] = c y = y[:, :, 1:, 1:] @@ -443,37 +372,10 @@ class Grids(problem.Problem): ): quizzes = quizzes.to("cpu") - if quizzes.size(1) == 4 * self.height * self.width: - quizzes = torch.cat( - [ - quizzes.new_zeros(quizzes.size(0), 4, 1), - quizzes.reshape(quizzes.size(0), 4, -1), - ], - dim=2, - ) - quizzes[:, :, 0] = torch.tensor( - [self.token_A, self.token_f_A, self.token_B, self.token_f_B] - )[None, :] - quizzes = quizzes.reshape(quizzes.size(0), -1) - - to_reconfigure = [quizzes] - if predicted_parts is not None: - to_reconfigure.append(predicted_parts) - if correct_parts is not None: - to_reconfigure.append(correct_parts) - - to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B")) - - quizzes = to_reconfigure.pop(0) - if predicted_parts is not None: - predicted_parts = to_reconfigure.pop(0) - if correct_parts is not None: - correct_parts = to_reconfigure.pop(0) - S = self.height * self.width A, f_A, B, f_B = ( - quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:] + quizzes.reshape(quizzes.size(0), 4, S) .reshape(quizzes.size(0), 4, self.height, self.width) .permute(1, 0, 2, 3) ) @@ -1881,7 +1783,7 @@ class Grids(problem.Problem): if tasks is None: tasks = self.all_tasks - quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) + quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64) if progress_bar: quizzes = tqdm.tqdm( @@ -1892,7 +1794,7 @@ class Grids(problem.Problem): ) for quiz in quizzes: - q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width) + q = quiz.reshape(4, self.height, self.width) q[...] = 0 A, f_A, B, f_B = q task = tasks[torch.randint(len(tasks), (1,)).item()] @@ -1932,6 +1834,9 @@ if __name__ == "__main__": ]: print(t.__name__) w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) + + w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size()) + grids.save_quizzes_as_image( "/tmp", t.__name__ + ".png", diff --git a/main.py b/main.py index 85f2cb6..5493b7d 100755 --- a/main.py +++ b/main.py @@ -250,18 +250,9 @@ assert args.nb_test_samples % args.batch_size == 0 ###################################################################### -def pure_noise(nb, device): - r = problem.pure_noise(nb, device) - r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1) - return r - - def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): if c_quizzes is None: quizzes = problem.generate_w_quizzes(nb_samples) - quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape( - quizzes.size(0), -1 - ) nb_w_quizzes = quizzes.size(0) nb_c_quizzes = 0 else: @@ -340,7 +331,7 @@ def add_noise_imt(imt_set): """Replace every component of the input by a random value with probability args.proba_prompt_noise.""" input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] - noise = pure_noise(input.size(0), input.device) + noise = problem.pure_noise(input.size(0), input.device) change = (1 - masks) * ( torch.rand(input.size(), device=input.device) < args.proba_prompt_noise ).long() @@ -428,7 +419,7 @@ def samples_for_generation_imt(input): proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t mask_erased = (r <= proba_erased[:, None]).long() - noise = pure_noise(nb, input.device) + noise = problem.pure_noise(nb, input.device) targets = input input = (1 - mask_erased) * input + mask_erased * noise masks = input.new_full(input.size(), 1) @@ -452,7 +443,7 @@ def ae_generate(model, nb, local_device=main_device): # mini-batches second so that we keep only the samples that have # not stabilized - all_input = pure_noise(nb, local_device) + all_input = problem.pure_noise(nb, local_device) all_masks = all_input.new_full(all_input.size(), 1) all_changed = torch.full((all_input.size(0),), True, device=all_input.device) -- 2.39.5