From d44d0605fed828b8cea08c8e1c5bda7e4528ea97 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 11 Mar 2023 17:45:02 +0100 Subject: [PATCH] Update. --- beaver.py | 354 +++++++++++------------------------------------------- maze.py | 47 +++++--- 2 files changed, 104 insertions(+), 297 deletions(-) diff --git a/beaver.py b/beaver.py index b0fa03c..4d4f98d 100755 --- a/beaver.py +++ b/beaver.py @@ -71,11 +71,11 @@ parser.add_argument("--overwrite_results", action="store_true", default=False) parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## -# picoclvr options +# maze options -parser.add_argument("--world_height", type=int, default=23) +parser.add_argument("--world_height", type=int, default=13) -parser.add_argument("--world_width", type=int, default=31) +parser.add_argument("--world_width", type=int, default=21) parser.add_argument("--world_nb_walls", type=int, default=15) @@ -83,8 +83,6 @@ parser.add_argument("--world_nb_walls", type=int, default=15) args = parser.parse_args() -assert args.prune_properties in {"none", "train+eval", "eval"} - try: os.mkdir(args.result_dir) except FileExistsError: @@ -122,9 +120,11 @@ for n in vars(args): ###################################################################### -def masked_inplace_autoregression( - model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu") -): +# ar_mask is a Boolean matrix of same shape as input, with 1s on the +# tokens that should be generated + + +def masked_inplace_autoregression(model, batch_size, input, ar_mask): for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)): i = (ar_mask.sum(0) > 0).nonzero() @@ -135,8 +135,6 @@ def masked_inplace_autoregression( for s in range(i.min(), i.max() + 1): output = model(mygpt.BracketedSequence(input, s, 1)).x logits = output[:, s] - if forbidden_tokens is not None: - logits = logits.masked_fill(forbidden_tokens, float("-inf")) if args.deterministic_synthesis: t_next = logits.argmax(1) else: @@ -161,176 +159,45 @@ class Task: ###################################################################### -import picoclvr - - -class TaskPicoCLVR(Task): - - # Make a tensor from a list of strings - def tensorize(self, descr): - token_descr = [s.strip().split(" ") for s in descr] - l = max([len(s) for s in token_descr]) - token_descr = [s + [""] * (l - len(s)) for s in token_descr] - id_descr = [[self.token2id[u] for u in s] for s in token_descr] - return torch.tensor(id_descr, device=self.device) - - # Make a list of strings from a tensor - def detensorize(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] - - # trim all the tensors in the tuple z to remove as much token from - # left and right in the first tensor. If z is a tuple, all its - # elements are trimed according to the triming for the first - def trim(self, z, token=""): - n = self.token2id[token] - if type(z) == tuple: - x = z[0] - i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) - a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() - return tuple([t[:, a:b] for t in z]) - else: - i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) - a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() - return z[:, a:b] - - ###################### - # Not the cleanest part of the code - - # Extract the last image of each sequence, from the last - # included, and set to all the tokens from the beginning of - # that image to the end - def excise_last_image(self, input): - t_img, t_nul = self.token2id[""], self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - - input = input.clone() - t = (input == t_img).long() - tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long() - i = (t * tail_masks).nonzero(as_tuple=True) - j = ( - i[0][:, None], - i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :], - ) - images = self.trim(input[j]) - input[j] = t_nul - loss_masks = 1 - tail_masks - input, loss_masks = self.trim((input, loss_masks)) - return input, loss_masks, images - - def add_true_image(self, input, images, loss_masks): - t_nul = self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - input = F.pad(input, (0, nb_img_tokens), value=t_nul) - loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0) - t = (input == t_nul).long() - i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True) - j = ( - i[0][:, None], - i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :], - ) - input[j] = images - loss_masks[j] = 1 - input, loss_masks = self.trim((input, loss_masks)) - return input, loss_masks - - def add_generated_image(self, input, loss_masks, model): - t_img, t_nul = self.token2id[""], self.token2id[""] - nb_img_tokens = self.height * self.width + 1 - - input = F.pad(input, (0, nb_img_tokens), value=t_nul) - loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0) - t = (input == t_nul).long() - i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True) - input[i] = t_img - - j = ( - i[0][:, None], - i[1][:, None] - + 1 - + torch.arange(nb_img_tokens - 1, device=input.device)[None, :], - ) - ar_masks = input.new_zeros(input.size(), dtype=torch.int64) - ar_masks[j] = 1 - forbidden_tokens = ( - torch.arange(self.vocabulary_size(), device=input.device) == t_nul - ) - with torch.autograd.no_grad(): - t = model.training - model.eval() - masked_inplace_autoregression( - model, - self.batch_size, - input, - ar_masks, - forbidden_tokens, - device=self.device, - ) - model.train(t) - - input, loss_masks = self.trim((input, loss_masks)) - - return input, loss_masks - - ###################### - - def __init__( - self, - batch_size, - height, - width, - nb_colors=5, - device=torch.device("cpu"), - pruner_train=None, - pruner_eval=None, - ): - def generate_descr(nb, cache_suffix, pruner): - return picoclvr.generate( - nb, - height=self.height, - width=self.width, - nb_colors=nb_colors, - pruner=pruner, - ) +import maze + + +class TaskMaze(Task): + def map2seq(self, *m): + return torch.cat([x.flatten(1) for x in m], 1) + + def seq2map(self, s): + s = s.reshape(s.size(0), -1, self.height, self.width) + return (s[:, k] for k in range(s.size(1))) + def __init__(self, batch_size, height, width, nb_walls, device=torch.device("cpu")): + self.batch_size = batch_size self.height = height self.width = width - self.batch_size = batch_size self.device = device + nb = args.data_size if args.data_size > 0 else 250000 - self.pruner_train = pruner_train - self.pruner_eval = pruner_eval - - param = { - "nb": nb, - "height": height, - "width": width, - "nb_colors": nb_colors, - "batch_size": batch_size, - "rng_state": list(torch.get_rng_state()), - } - - log_string(f"generating {nb} samples (can take some time)") - self.train_descr = generate_descr( - (nb * 4) // 5, "train", pruner=self.pruner_train + + mazes_train, paths_train = maze.create_maze_data( + (4 * nb) // 5, + height=height, + width=width, + nb_walls=nb_walls, + progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), + ) + mazes_train, paths_train = mazes_train.to(device), paths_train.to(device) + self.train_input = self.map2seq(mazes_train, paths_train) + self.nb_codes = self.train_input.max() + 1 + + mazes_test, paths_test = maze.create_maze_data( + nb // 5, + height=height, + width=width, + nb_walls=nb_walls, + progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), ) - self.test_descr = generate_descr((nb * 1) // 5, "test", pruner=None) - - # Build the tokenizer - tokens = {"", ""} - for d in [self.train_descr, self.test_descr]: - for s in d: - for t in s.strip().split(" "): - tokens.add(t) - # make this set a sorted list to get the same tensors given - # the same descr - tokens = list(tokens) - tokens.sort() - self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) - self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) - - # Tokenize the train and test sets - self.train_input = self.tensorize(self.train_descr) - self.test_input = self.tensorize(self.test_descr) + mazes_test, paths_test = mazes_test.to(device), paths_test.to(device) + self.test_input = self.map2seq(mazes_test, paths_test) def batches(self, split="train"): assert split in {"train", "test"} @@ -338,111 +205,45 @@ class TaskPicoCLVR(Task): for batch in tqdm.tqdm( input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" ): - yield self.trim(batch) + yield batch def vocabulary_size(self): - return len(self.token2id) + return self.nb_codes - def compute_missing_properties(self, n_epoch, model, pruner=None): - - acc_nb_requested_properties = [] - acc_nb_missing_properties = [] - acc_nb_results = 0 - - for input in tqdm.tqdm( - self.test_input.split(self.batch_size), - dynamic_ncols=True, - desc=f"test-properties", - ): - tape, loss_masks, _ = self.excise_last_image(input) - tape, loss_masks = self.add_generated_image(tape, loss_masks, model) - result_descr = self.detensorize(tape) - np = picoclvr.nb_properties( - result_descr, - height=self.height, - width=self.width, - pruner=pruner, - ) - nb_requested_properties, _, nb_missing_properties = zip(*np) - acc_nb_requested_properties += nb_requested_properties - acc_nb_missing_properties += nb_missing_properties - acc_nb_results += len(result_descr) - - nb_requested_properties = sum(acc_nb_requested_properties) - nb_missing_properties = sum(acc_nb_missing_properties) - - prefix = "" if pruner is None else "pruned_" - log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}") - log_string( - f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}" - ) - log_string( - f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" - ) + def compute_error(self, model, split="train"): + nb_total, nb_correct = 0, 0 + for input in task.batches(split): + result = input.clone() + ar_mask = result.new_zeros(result.size()) + ar_mask[:, self.height * self.width :] = 1 + masked_inplace_autoregression(model, self.batch_size, result, ar_mask) + mazes, paths = self.seq2map(result) + nb_correct += maze.path_correctness(mazes, paths).long().sum() + nb_total += mazes.size(0) - ###################################################################### + return nb_total, nb_correct def produce_results(self, n_epoch, model): - - self.compute_missing_properties(n_epoch, model) - - if self.pruner_eval is not None: - self.compute_missing_properties(n_epoch, model, self.pruner_eval) - - nb_tokens_to_generate = self.height * self.width + 3 - result_descr = [] - nb_per_primer = 8 - primer = [] - - for primer_descr in [ - "red above green green top blue right of red", - "there is red there is yellow there is blue", - "red below yellow yellow below green green below blue red right yellow left green right blue left", - "green bottom yellow bottom green left of blue yellow right of blue blue top", - ]: - primer += [primer_descr] * nb_per_primer - - tape = self.tensorize(primer) - loss_masks = 1 - (tape == self.token2id[""]).long() - tape, loss_masks = self.add_generated_image(tape, loss_masks, model) - result_descr = self.detensorize(tape) - - np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width) - - acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np) - acc_nb_results = len(result_descr) - - nb_requested_properties = sum(acc_nb_requested_properties) - nb_missing_properties = sum(acc_nb_missing_properties) - - prefix = "demo_" - log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}") + train_nb_total, train_nb_correct = self.compute_error(model, "train") log_string( - f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}" + f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" ) + + test_nb_total, test_nb_correct = self.compute_error(model, "test") log_string( - f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" + f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) - img = picoclvr.descr2img(result_descr, height=self.height, width=self.width) + input = self.test_input[:32] + result = input.clone() + ar_mask = result.new_zeros(result.size()) - if img.dim() == 5: - if img.size(1) == 1: - img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64) - else: - img = torch.cat( - [ - torchvision.utils.make_grid(x, padding=1, pad_value=64)[None] - for x in img - ], - 0, - ) - - image_name = os.path.join(args.result_dir, f"result_{n_epoch:04d}.png") - torchvision.utils.save_image( - img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0 - ) - log_string(f"wrote {image_name}") + ar_mask[:, self.height * self.width :] = 1 + masked_inplace_autoregression(model, self.batch_size, result, ar_mask) + + mazes, paths = self.seq2map(input) + _, predicted_paths = self.seq2map(result) + maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths) ###################################################################### @@ -450,24 +251,15 @@ class TaskPicoCLVR(Task): log_string(f"device {device}") -def pruner_horizontal_green(p): - return not ("green" in p and ("left" in p or "right" in p)) - - -task = TaskPicoCLVR( +task = TaskMaze( batch_size=args.batch_size, - height=args.height, - width=args.width, - nb_colors=args.nb_colors, + height=args.world_height, + width=args.world_width, + nb_walls=args.world_nb_walls, device=device, - pruner_train=pruner_horizontal_green - if args.prune_properties in {"train+eval"} - else None, - pruner_eval=(lambda p: not pruner_horizontal_green(p)) - if args.prune_properties in {"train+eval", "eval"} - else None, ) + vocabulary_size = task.vocabulary_size() log_string(f"vocabulary_size {vocabulary_size}") diff --git a/maze.py b/maze.py index 2c44319..f4a4840 100755 --- a/maze.py +++ b/maze.py @@ -129,14 +129,12 @@ def mark_path(walls, i, j, goal_i, goal_j): assert n < nmax -def valid_paths(mazes, paths): +def path_correctness(mazes, paths): still_ok = (mazes - (paths * (paths < 4))).view(mazes.size(0), -1).abs().sum(1) == 0 reached = still_ok.new_zeros(still_ok.size()) current, pred_current = paths.clone(), paths.new_zeros(paths.size()) goal = (mazes == v_goal).long() while not pred_current.equal(current): - # print(current) - # print(f'{still_ok=} {reached=}') pred_current.copy_(current) u = (current == v_start).long() possible_next = ( @@ -157,12 +155,14 @@ def valid_paths(mazes, paths): ###################################################################### -def create_maze_data(nb, h=11, w=17, nb_walls=8, dist_min=-1): - mazes = torch.empty(nb, h, w, dtype=torch.int64) - paths = torch.empty(nb, h, w, dtype=torch.int64) +def create_maze_data( + nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x +): + mazes = torch.empty(nb, height, width, dtype=torch.int64) + paths = torch.empty(nb, height, width, dtype=torch.int64) - for n in range(nb): - maze = create_maze(h, w, nb_walls) + for n in progress_bar(range(nb)): + maze = create_maze(height, width, nb_walls) i = (1 - maze).nonzero() while True: start, goal = i[torch.randperm(i.size(0))[:2]] @@ -185,8 +185,8 @@ def create_maze_data(nb, h=11, w=17, nb_walls=8, dist_min=-1): ###################################################################### -def save_image(name, mazes, paths): - mazes, paths = mazes.cpu(), paths.cpu() +def save_image(name, mazes, target_paths, predicted_paths=None): + mazes, target_paths = mazes.cpu(), target_paths.cpu() colors = torch.tensor( [ @@ -199,20 +199,35 @@ def save_image(name, mazes, paths): ) mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2) - paths = colors[paths.reshape(-1)].reshape(paths.size() + (-1,)).permute(0, 3, 1, 2) + target_paths = ( + colors[target_paths.reshape(-1)] + .reshape(target_paths.size() + (-1,)) + .permute(0, 3, 1, 2) + ) + img = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1) + + if predicted_paths is not None: + predicted_paths = predicted_paths.cpu() + predicted_paths = ( + colors[predicted_paths.reshape(-1)] + .reshape(predicted_paths.size() + (-1,)) + .permute(0, 3, 1, 2) + ) + img = torch.cat((img, predicted_paths.unsqueeze(1)), 1) - img = torch.cat((mazes.unsqueeze(1), paths.unsqueeze(1)), 1) img = img.reshape((-1,) + img.size()[2:]).float() / 255.0 - torchvision.utils.save_image(img, name, padding=1, pad_value=0.5, nrow=8) + torchvision.utils.save_image(img, name, padding=1, pad_value=0.85, nrow=6) ###################################################################### if __name__ == "__main__": - mazes, paths = create_maze_data(32, dist_min=10) - save_image("test.png", mazes, paths) - print(valid_paths(mazes, paths)) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + mazes, paths = create_maze_data(8) + mazes, paths = mazes.to(device), paths.to(device) + save_image("test.png", mazes, paths, paths) + print(path_correctness(mazes, paths)) ###################################################################### -- 2.20.1