From 6bd776c5842485db888d81e756e22623e8dc949f Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Fri, 21 Jun 2024 08:29:04 +0200
Subject: [PATCH] Update.

---
 tasks.py | 48 ++++++++++++++----------------
 world.py | 75 ++++++++++++++++++++++++++++----------------------------
 2 files changed, 55 insertions(+), 68 deletions(-)

diff --git a/tasks.py b/tasks.py
index 7894fcd..b4e6f67 100755
--- a/tasks.py
+++ b/tasks.py
@@ -500,38 +500,26 @@ class World(Task):
 
         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
 
-        if save_attention_image is not None:
-            for k in range(10):
-                ns = torch.randint(self.test_input.size(0), (1,)).item()
-                input = self.test_input[ns : ns + 1].clone()
+        ##############################
 
-                with torch.autograd.no_grad():
-                    t = model.training
-                    model.eval()
-                    # model.record_attention(True)
-                    model(BracketedSequence(input))
-                    model.train(t)
-                    # ram = model.retrieve_attention()
-                    # model.record_attention(False)
+        input, ar_mask = self.test_input[:64], self.test_ar_mask[:64]
+        result = input.clone() * (1 - ar_mask)
 
-                # tokens_output = [c for c in self.problem.seq2str(input[0])]
-                # tokens_input = ["n/a"] + tokens_output[:-1]
-                # for n_head in range(ram[0].size(1)):
-                # filename = os.path.join(
-                # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
-                # )
-                # attention_matrices = [m[0, n_head] for m in ram]
-                # save_attention_image(
-                # filename,
-                # tokens_input,
-                # tokens_output,
-                # attention_matrices,
-                # k_top=10,
-                ##min_total_attention=0.9,
-                # token_gap=12,
-                # layer_gap=50,
-                # )
-                # logger(f"wrote {filename}")
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            progress_bar_desc=None,
+            device=self.device,
+        )
+
+        img = world.sample2img(result.to("cpu"), self.height, self.width)
+
+        image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+        logger(f"wrote {image_name}")
 
 ######################################################################
 
diff --git a/world.py b/world.py
index 0392940..ac201e7 100755
--- a/world.py
+++ b/world.py
@@ -27,76 +27,75 @@ colors = torch.tensor(
     ]
 )
 
-token2char = "_X01234>"
+token2char = "_X" + "".join([str(n) for n in range(len(colors) - 2)]) + ">"
 
 
 def generate(
     nb,
     height,
     width,
-    obj_length=6,
-    mask_height=3,
-    mask_width=3,
-    nb_obj=3,
+    max_nb_obj=len(colors) - 2,
+    nb_iterations=2,
 ):
-    intact = torch.zeros(nb, height, width, dtype=torch.int64)
-    n = torch.arange(intact.size(0))
+    f_start = torch.zeros(nb, height, width, dtype=torch.int64)
+    f_end = torch.zeros(nb, height, width, dtype=torch.int64)
+    n = torch.arange(f_start.size(0))
 
     for n in range(nb):
-        for c in torch.randperm(colors.size(0) - 2)[:nb_obj] + 2:
-            z = intact[n].flatten()
-            m = (torch.rand(z.size()) * (z == 0)).argmax(dim=0)
-            i, j = m // width, m % width
+        nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
+        for c in range(nb_fish):
+            i, j = (
+                torch.randint(height - 2, (1,))[0] + 1,
+                torch.randint(width - 2, (1,))[0] + 1,
+            )
             vm = torch.randint(4, (1,))[0]
             vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)
-            for l in range(obj_length):
-                intact[n, i, j] = c
+
+            f_start[n, i, j] = c + 2
+            f_start[n, i - vi, j - vj] = c + 2
+            f_start[n, i + vj, j - vi] = c + 2
+            f_start[n, i - vj, j + vi] = c + 2
+
+            for l in range(nb_iterations):
                 i += vi
                 j += vj
-                if i < 0 or i >= height or j < 0 or j >= width or intact[n, i, j] != 0:
+                if i < 0 or i >= height or j < 0 or j >= width:
                     i -= vi
                     j -= vj
-                    vi, vj = -vj, vi
+                    vi, vj = -vi, -vj
                     i += vi
                     j += vj
-                    if (
-                        i < 0
-                        or i >= height
-                        or j < 0
-                        or j >= width
-                        or intact[n, i, j] != 0
-                    ):
-                        break
 
-    masked = intact.clone()
-
-    for n in range(nb):
-        i = torch.randint(height - mask_height + 1, (1,))[0]
-        j = torch.randint(width - mask_width + 1, (1,))[0]
-        masked[n, i : i + mask_height, j : j + mask_width] = 1
+            f_end[n, i, j] = c + 2
+            f_end[n, i - vi, j - vj] = c + 2
+            f_end[n, i + vj, j - vi] = c + 2
+            f_end[n, i - vj, j + vi] = c + 2
 
     return torch.cat(
         [
-            masked.flatten(1),
-            torch.full((masked.size(0), 1), len(colors)),
-            intact.flatten(1),
+            f_end.flatten(1),
+            torch.full((f_end.size(0), 1), len(colors)),
+            f_start.flatten(1),
         ],
         dim=1,
     )
 
 
 def sample2img(seq, height, width):
-    intact = seq[:, : height * width].reshape(-1, height, width)
-    masked = seq[:, height * width + 1 :].reshape(-1, height, width)
-    img_intact, img_masked = colors[intact], colors[masked]
+    f_start = seq[:, : height * width].reshape(-1, height, width)
+    f_start = (f_start >= len(colors)).long() + (f_start < len(colors)).long() * f_start
+    f_end = seq[:, height * width + 1 :].reshape(-1, height, width)
+    f_end = (f_end >= len(colors)).long() + (f_end < len(colors)).long() * f_end
+
+    img_f_start, img_f_end = colors[f_start], colors[f_end]
 
     img = torch.cat(
         [
-            img_intact,
+            img_f_start,
             torch.full(
-                (img_intact.size(0), img_intact.size(1), 1, img_intact.size(3)), 1
+                (img_f_start.size(0), img_f_start.size(1), 1, img_f_start.size(3)), 1
            ),
-            img_masked,
+            img_f_end,
         ],
         dim=2,
     )
-- 
2.39.5