From 8741b753b05bb8b89a65d43af6d7637771979d25 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Wed, 24 Jul 2024 11:31:48 +0200
Subject: [PATCH] Update: encode quizzes as (A, f_A, B, f_B) structures and
 rework the grid rendering.

---
 grids.py        | 376 +++++++++++++++++-------------------------------
 problem.py      |   8 +-
 quiz_machine.py | 133 -----------------
 3 files changed, 137 insertions(+), 380 deletions(-)

diff --git a/grids.py b/grids.py
index e64cb33..131f85c 100755
--- a/grids.py
+++ b/grids.py
@@ -118,80 +118,49 @@ class Grids(problem.Problem):
         ("gray", [128, 128, 128]),
     ]
 
-    def make_ar_mask(self, quizzes, shape="fwd_3_bck_123"):
+    def check_structure(self, quizzes, struct):
         S = self.height * self.width
-        assert (
-            (
-                (quizzes[:, 0] == self.token_forward)
-                | (quizzes[:, 0] == self.token_backward)
-            )
-            & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
-            & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
-            & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
+        return (
+            (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
+            & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
+            & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
+            & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
         ).all()
 
-        T = torch.arange(quizzes.size(1), device=quizzes.device)
-
-        if shape == "fwd_3_bck_123":
-            forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T >= 1 * (S + 1))).long()
-        elif shape == "fwd_012_bck_0":
-            forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T < 1 * (S + 1))).long()
-        elif shape == "fwd_3_bck_3":
-            forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
-        else:
-            raise ValueError(shape)
-
-        is_forward = (quizzes[:, 0] == self.token_forward).long()
+    def make_ar_mask(self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)):
+        assert self.check_structure(quizzes, struct)
 
-        return (
-            is_forward[:, None] * forward_mask[None, :]
-            + (1 - is_forward)[:, None] * backward_mask[None, :]
-        )
+        ar_mask = quizzes.new_zeros(quizzes.size())
 
-    def p_a_flip(self, quizzes, pairwise_flip=False):
-        S = self.height * self.width
+        S = self.height * self.width
+        a = ar_mask.reshape(-1, 4, S + 1)[:, :, 1:]
+        a[:, 0, :] = mask[0]
+        a[:, 1, :] = mask[1]
+        a[:, 2, :] = mask[2]
+        a[:, 3, :] = mask[3]
 
-        assert (
-            (
-                (quizzes[:, 0] == self.token_forward)
-                | (quizzes[:, 0] == self.token_backward)
-            )
-            & (quizzes[:, 0] == quizzes[:, 1 * (S + 1)])
-            & (quizzes[:, 0] == quizzes[:, 2 * (S + 1)])
-            & (quizzes[:, 0] == quizzes[:, 3 * (S + 1)])
-        ).all()
+        return ar_mask
 
-        if pairwise_flip:
-            flipped = torch.cat(
-                [
-                    quizzes[:, 1 * (S + 1) : 1 * (S + 1) + S + 1],
-                    quizzes[:, 0 * (S + 1) : 0 * (S + 1) + S + 1],
-                    quizzes[:, 3 * (S + 1) : 3 * (S + 1) + S + 1],
-                    quizzes[:, 2 * (S + 1) : 2 * (S + 1) + S + 1],
-                ],
-                dim=1,
-            )
-        else:
-            flipped_from_forward = torch.cat(
-                [quizzes[:, 3 * (S + 1) :], quizzes[:, : 3 * (S + 1)]],
-                dim=1,
-            )
-            flipped_from_forward[:, torch.arange(4) * (S + 1)] = self.token_backward
+    def reconfigure(
+        self,
+        quizzes,
+        struct_from=("A", "f_A", "B", "f_B"),
+        struct_to=("f_B", "A", "f_A", "B"),
+    ):
+        assert self.check_structure(quizzes, struct_from)
 
-            flipped_from_backward = torch.cat(
-                [quizzes[:, S + 1 :], quizzes[:, : S + 1]], dim=1
-            )
-            flipped_from_backward[:, torch.arange(4) * (S + 1)] = self.token_forward
+        S = self.height * self.width
+        sf = dict((l, n) for n, l in enumerate(struct_from))
 
-        m = (quizzes[:, 0] == self.token_forward).long()[:, None]
+        result = quizzes.new(quizzes.size())
+        q = quizzes.reshape(-1, 4, S + 1)
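+        # view quizzes and result as four (S + 1)-token segments and copy
+        # the segments in the order given by struct_to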
+        r = result.reshape(-1, 4, S + 1)
 
-        flipped = m * flipped_from_forward + (1 - m) * flipped_from_backward
+        r[:, 0, :] = q[:, sf[struct_to[0]]]
+        r[:, 1, :] = q[:, sf[struct_to[1]]]
+        r[:, 2, :] = q[:, sf[struct_to[2]]]
+        r[:, 3, :] = q[:, sf[struct_to[3]]]
 
-        return flipped
+        return result
 
     def __init__(
         self,
@@ -201,8 +170,20 @@ class Grids(problem.Problem):
         tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
-        self.token_forward = len(self.colors)
-        self.token_backward = self.token_forward + 1
+
+        self.token_A = len(self.colors)
+        self.token_f_A = self.token_A + 1
+        self.token_B = self.token_f_A + 1
+        self.token_f_B = self.token_B + 1
+        self.l2tok = {
+            "A": self.token_A,
+            "f_A": self.token_f_A,
+            "B": self.token_B,
+            "f_B": self.token_f_B,
+        }
+
+        self.nb_token_values = self.token_f_B + 1
+
         self.height = 10
         self.width = 10
         self.cache_rec_coo = {}
@@ -237,8 +218,7 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def frame2img(self, x, scale=15):
-        x = x.reshape(x.size(0), self.height, self.width)
+    def grid2img(self, x, scale=15):
         m = torch.logical_and(x >= 0, x < len(self.colors)).long()
         y = self.colors[x * m].permute(0, 3, 1, 2)
         s = y.shape
@@ -247,154 +227,95 @@ class Grids(problem.Problem):
         y[:, :, :, torch.arange(0, y.size(3), scale)] = 0
         y[:, :, torch.arange(0, y.size(2), scale), :] = 0
-        y = y[:, :, 1:, 1:]
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
-                    if x[n, i, j] == self.token_forward:
-                        for k in range(2, scale - 2):
-                            y[
-                                n,
-                                :,
-                                i * scale + k,
-                                j * scale + scale - 5 - abs(k - scale // 2),
-                            ] = 0
-
-                    elif x[n, i, j] == self.token_backward:
-                        for k in range(2, scale - 2):
-                            y[
-                                n, :, i * scale + k, j * scale + 3 + abs(k - scale // 2)
-                            ] = 0
-                    # y[n, :, i * scale + k, j * scale + k - l] = 0
-                    # y[
-                    #     n, :, i * scale + scale - 1 - k, j * scale + k - l
-                    # ] = 0
+                    if m[n, i, j] == 0:
+                        for k in range(3, scale - 2):
+                            y[n, :, i * scale + k, j * scale + k] = 0
+                            y[n, :, i * scale + k, j * scale + scale - k] = 0
+
+        y = y[:, :, 1:, 1:]
 
         return y
 
-    def save_image(
+    def add_frame(self, img, colors, thickness):
+        result = img.new(
+            img.size(0),
+            img.size(1),
+            img.size(2) + 2 * thickness,
+            img.size(3) + 2 * thickness,
+        )
+
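+        # fill the enlarged tensor with the frame color, then paste the
+        # original image in the center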
+        result[...] = colors[:, :, None, None]
+        result[:, :, thickness:-thickness, thickness:-thickness] = img
+
+        return result
+
+    def save_quizzes_as_image(
         self,
         result_dir,
         filename,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
+        quizzes,
+        predicted_parts=None,
+        correct_parts=None,
         nrow=4,
         margin=8,
     ):
         S = self.height * self.width
-        As = prompts[:, 0 * (S + 1) + 1 : 0 * (S + 1) + S + 1].view(
-            -1, self.height, self.width
-        )
-        f_As = prompts[:, 1 * (S + 1) + 1 : 1 * (S + 1) + S + 1].view(
-            -1, self.height, self.width
-        )
-        Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
-            -1, self.height, self.width
+
+        A, f_A, B, f_B = (
+            quizzes.reshape(-1, 4, S + 1)[:, :, 1:]
+            .reshape(-1, 4, self.height, self.width)
+            .permute(1, 0, 2, 3)
         )
-        prompts = torch.cat([As, f_As, Bs], dim=2)
-        answers = answers[:, 1 : S + 1].reshape(
-            answers.size(0), self.height, self.width
+
+        black, white, gray, green, red = torch.tensor(
+            [[0, 0, 0], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]],
+            device=quizzes.device,
         )
 
-        if predicted_prompts is None:
-            predicted_prompts = 255
+        img_A = self.add_frame(self.grid2img(A), black[None, :], thickness=1)
+        img_f_A = self.add_frame(self.grid2img(f_A), black[None, :], thickness=1)
+        img_B = self.add_frame(self.grid2img(B), black[None, :], thickness=1)
+        img_f_B = self.add_frame(self.grid2img(f_B), black[None, :], thickness=1)
 
-        if predicted_answers is None:
-            predicted_answers = 255
+        # predicted_parts Nx4
+        # correct_parts Nx4
 
-        def add_frame(x, c, margin, bottom=False):
-            if bottom:
-                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
-            else:
-                h, w, di, dj = (
-                    x.size(2) + 2 * margin,
-                    x.size(3) + 2 * margin,
-                    margin,
-                    margin,
+        if predicted_parts is None:
+            colors = white[None, None, :].expand(-1, 4, -1)
+        else:
+            if correct_parts is None:
+                colors = (
+                    predicted_parts[:, :, None] * gray[None, None, :]
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
                 )
-
-            y = x.new_full((x.size(0), x.size(1), h, w), 0)
-
-            if type(c) is int:
-                y[...] = c
             else:
-                c = c.long()[:, None]
-                c = (
-                    (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
-                    * torch.tensor([64, 64, 64])
-                    + (c == 1).long() * torch.tensor([0, 255, 0])
-                    + (c == 0).long() * torch.tensor([255, 255, 255])
-                    + (c == -1).long() * torch.tensor([255, 0, 0])
-                )
-                y[...] = c[:, :, None, None]
-
-            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
-
-            return y
-
-        img_prompts = torch.cat(
-            [
-                add_frame(
-                    add_frame(self.frame2img(x), c=0, margin=1),
-                    c=predicted_prompts,
-                    margin=margin,
+                colors = (
+                    predicted_parts[:, :, None]
+                    * (
+                        correct_parts[:, :, None] * green[None, None, :]
+                        + (1 - correct_parts[:, :, None]) * red[None, None, :]
+                    )
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
                 )
-                for x in prompts.to("cpu").split(split_size=self.width, dim=2)
-            ],
-            dim=3,
-        )
-
-        h = img_prompts.size(2)
-        img_answers = add_frame(
-            add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
-            c=predicted_answers,
-            margin=margin,
-        )
-
-        separator_size = 2 * margin
-
-        separator = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                separator_size,
-            ),
-            255,
-        )
-
-        marker = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                separator_size,
-            ),
-            255,
-        )
-        # marker[:, :, 0] = 0
-        # marker[:, :, h - 1] = 0
+        img_A = self.add_frame(img_A, colors[:, 0], thickness=6)
+        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=6)
+        img_B = self.add_frame(img_B, colors[:, 2], thickness=6)
+        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=6)
 
-        for k in range(1, 2 * separator_size - 8):
-            i = k - (separator_size - 4)
-            j = separator_size - 5 - abs(i)
-            marker[:, :, h // 2 - 1 + i, 2 + j] = 0
-            marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+        img_A = self.add_frame(img_A, white[None, :], thickness=2)
+        img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+        img_B = self.add_frame(img_B, white[None, :], thickness=2)
+        img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
 
-        img = torch.cat(
-            [
-                img_prompts,
-                marker,
-                img_answers,
-            ],
-            dim=3,
-        )
+        img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
 
         image_name = os.path.join(result_dir, filename)
+
         torchvision.utils.save_image(
             img.float() / 255.0,
             image_name,
@@ -405,9 +326,6 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def nb_token_values(self):
-        return len(self.colors) + 2
-
     # @torch.compile
     def rec_coo(
         self,
@@ -1444,70 +1362,36 @@ class Grids(problem.Problem):
         f_Bs = answers[:, 1:]
         return (Bs == f_Bs).long().min(dim=-1).values > 0
 
-    def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
+    def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
         if tasks is None:
             tasks = self.all_tasks
 
         S = self.height * self.width
 
-        prompts = torch.full((nb, 3 * S + 3), self.token_forward)
-        answers = torch.full((nb, S + 1), self.token_forward)
-
-        bunch = zip(prompts, answers)
+        quizzes = torch.empty(nb, 4 * (S + 1), dtype=torch.int64)
+
+        # the tasks only fill the four grids, so the structure tokens
+        # have to be set explicitly
+        quizzes[:, 0 * (S + 1)] = self.l2tok["A"]
+        quizzes[:, 1 * (S + 1)] = self.l2tok["f_A"]
+        quizzes[:, 2 * (S + 1)] = self.l2tok["B"]
+        quizzes[:, 3 * (S + 1)] = self.l2tok["f_B"]
+
+        bunch = quizzes
 
         if progress_bar:
             bunch = tqdm.tqdm(
                 bunch,
                 dynamic_ncols=True,
-                desc="world generation",
-                total=prompts.size(0),
+                desc="world quizzes generation",
+                total=quizzes.size(0),
             )
 
-        for prompt, answer in bunch:
-            A = prompt[0 * (S + 1) + 1 : 0 * (S + 1) + 1 + S].view(
-                self.height, self.width
-            )
-            f_A = prompt[1 * (S + 1) + 1 : 1 * (S + 1) + 1 + S].view(
-                self.height, self.width
-            )
-            B = prompt[2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view(
-                self.height, self.width
-            )
-            f_B = answer[1 : S + 1].view(self.height, self.width)
+        for quiz in bunch:
+            # the slicing + reshape of the contiguous quiz yields a view,
+            # so the task writes directly into the quiz tensor
+            q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
+            q[...] = 0
+            A, f_A, B, f_B = q
             task = tasks[torch.randint(len(tasks), (1,)).item()]
-            A[...] = 0
-            f_A[...] = 0
-            B[...] = 0
-            f_B[...] = 0
+            task(A, f_A, B, f_B)
 
-        return prompts.flatten(1), answers.flatten(1)
-
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-        nrow=4,
-    ):
-        self.save_image(
-            result_dir,
-            filename_prefix + ".png",
-            prompts,
-            answers,
-            predicted_prompts,
-            predicted_answers,
-            nrow,
-        )
+        return quizzes
 
     def save_some_examples(self, result_dir):
         nb, nrow = 128, 4
         for t in self.all_tasks:
             print(t.__name__)
-            prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
-            self.save_quiz_illustrations(
-                result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+            quizzes = self.generate_w_quizzes_(nb, tasks=[t])
+            self.save_quizzes_as_image(
+                result_dir, t.__name__ + ".png", quizzes, nrow=nrow
             )
@@ -1526,7 +1410,7 @@ if __name__ == "__main__":
     #     )
    #     time.sleep(10)
    #     start_time = time.perf_counter()
-    #     prompts, answers = grids.generate_prompts_and_answers(nb)
+    #     prompts, answers = grids.generate_w_quizzes(nb)
    #     delay = time.perf_counter() - start_time
    #     print(f"{prompts.size(0)/delay:02f} seq/s")
    #     exit(0)
@@ -1536,13 +1420,19 @@ if __name__ == "__main__":
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_reconfigure]:
+    for t in [grids.task_replace_color]:
         # for t in [grids.task_symbols]:
         print(t.__name__)
-        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-        # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
-        grids.save_quiz_illustrations(
-            "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+        quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+        predicted_parts = quizzes.new_zeros(quizzes.size(0), 4)
+        predicted_parts[:, 3] = 1
+        correct_parts = torch.randint(2, (quizzes.size(0), 4), device=quizzes.device)
+        grids.save_quizzes_as_image(
+            "/tmp",
+            t.__name__ + ".png",
+            quizzes,
+            predicted_parts=predicted_parts,
+            correct_parts=correct_parts,
         )
 
     exit(0)
@@ -1552,7 +1442,7 @@ if __name__ == "__main__":
 
     for t in grids.all_tasks:
         # for t in [grids.task_compute]:
         start_time = time.perf_counter()
-        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+        quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
-        print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
+        print(f"{t.__name__} {quizzes.size(0)/delay:02f} seq/s")

diff --git a/problem.py b/problem.py
index 05f3b20..61e4834 100755
--- a/problem.py
+++ b/problem.py
@@ -32,7 +32,7 @@ class Problem:
         pass
 
     # The one to implement, returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers_(self, nb):
+    def generate_w_quizzes_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
@@ -49,13 +49,13 @@ class Problem:
 
     def fill_cache(self):
         while True:
-            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
+            prompts, answers = self.generate_w_quizzes_(self.chunk_size)
 
             self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
 
-    def generate_prompts_and_answers(self, nb):
+    def generate_w_quizzes(self, nb):
         if self.queue is None:
-            return self.generate_prompts_and_answers_(nb)
+            return self.generate_w_quizzes_(nb)
 
         if self.rest is not None:
             prompts, answers = rest

diff --git a/quiz_machine.py b/quiz_machine.py
index e70b903..d62ba3b 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -599,136 +599,3 @@ class QuizMachine:
         return c_quizzes.to("cpu")
 
     ######################################################################
-
-    def generate_c_quizzes_mixing(
-        self,
-        nb,
-        model_for_generation,
-        p2a_only=False,
-        temperature_hot=1.0,
-        temperature_cold=1.0,
-    ):
-        c_quizzes = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        c_quizzes_1 = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        c_quizzes_2 = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        lt_noisy = lambda s, logits: logits / temperature_hot
-        lt_clean = lambda s, logits: logits / temperature_cold
-
-        ######################################################################
-
-        c_quizzes_1[...] = self.problem.token_backward
-        ar_mask = self.problem.make_ar_mask(c_quizzes_1, shape="fwd_012_bck_0")
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes_1,
-            ar_mask=ar_mask,
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_1", c_quizzes_1)
-
-        c_quizzes_2[...] = self.problem.token_backward
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes_2,
-            ar_mask=ar_mask,
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_noisy,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_2", c_quizzes_2)
-
-        h = len(model_for_generation.trunk) // 2
-
-        with torch.autograd.no_grad():
-            t = model_for_generation.training
-            model_for_generation.eval()
-
-            bs1 = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(c_quizzes_1), end_layer=h
-            )
-            bs2 = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(c_quizzes_2), end_layer=h
-            )
-
-            alpha = 0.5
-
-            output = model_for_generation.partial_forward(
-                mygpt.BracketedSequence(alpha * bs1.x + (1 - alpha) * bs2.x),
-                start_layer=h,
-            ).x
-
-            dist = torch.distributions.categorical.Categorical(logits=output)
-            c_quizzes[...] = dist.sample()
-
-            c_quizzes[...] = (
-                ar_mask * c_quizzes + (1 - ar_mask) * self.problem.token_backward
-            )
-
-            model_for_generation.train(t)
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes", c_quizzes)
-
-        ######################################################################
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_A", c_quizzes)
-
-        c_quizzes = self.problem.p_a_flip(c_quizzes)
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
-            seq_logproba=seq_logproba,
-            logit_transformer=lt_clean,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        self.save_quiz_illustrations("/tmp", f"c_quizzes_B", c_quizzes)
-
-        print("DONE")
-        exit(0)
-
-        return c_quizzes.to("cpu")
-- 
2.39.5
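
A minimal usage sketch of the reworked Grids API above, assuming the default
constructor arguments; the batch size, file name and output directory below
are arbitrary.

    import grids

    p = grids.Grids()

    # world quizzes with the canonical structure (A, f_A, B, f_B)
    quizzes = p.generate_w_quizzes_(128)
    assert p.check_structure(quizzes, ("A", "f_A", "B", "f_B"))

    # mask selecting the f_B grid, e.g. for auto-regressive prediction
    ar_mask = p.make_ar_mask(quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1))

    # same quizzes with their four parts permuted to (f_B, A, f_A, B)
    flipped = p.reconfigure(
        quizzes, ("A", "f_A", "B", "f_B"), ("f_B", "A", "f_A", "B")
    )

    p.save_quizzes_as_image("/tmp", "grids_sketch.png", quizzes[:16])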