X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=49cb1f672075795f84021e9e8a5d6573a6d6f56e;hb=10c1e2159ef57a55724fb1753381dc30e8aa77c2;hp=920a446f920e6cd2beb67ac0df96457bfac55225;hpb=61cd7a140e44ccb966bad941fa31e395e51e50e2;p=beaver.git diff --git a/beaver.py b/beaver.py index 920a446..49cb1f6 100755 --- a/beaver.py +++ b/beaver.py @@ -26,9 +26,7 @@ else: ###################################################################### -parser = argparse.ArgumentParser( - description="An implementation of GPT with cache to solve a toy geometric reasoning task." -) +parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.") parser.add_argument("--log_filename", type=str, default="train.log") @@ -66,6 +64,8 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--random_regression_order", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) @@ -75,11 +75,20 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # maze options -parser.add_argument("--world_height", type=int, default=13) +parser.add_argument("--maze_height", type=int, default=13) + +parser.add_argument("--maze_width", type=int, default=21) + +parser.add_argument("--maze_nb_walls", type=int, default=15) + +############################## +# one-shot prediction -parser.add_argument("--world_width", type=int, default=21) +parser.add_argument("--oneshot", action="store_true", default=False) -parser.add_argument("--world_nb_walls", type=int, default=15) +parser.add_argument("--oneshot_input", type=str, default="head") + +parser.add_argument("--oneshot_output", type=str, default="trace") ###################################################################### @@ -122,20 +131,38 @@ for n in vars(args): ###################################################################### +def random_order(result, fixed_len): + if args.random_regression_order: + order = torch.rand(result.size(), device=result.device) + order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) + return order.sort(1).indices + else: + return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1) + + +def shuffle(x, order, reorder=False): + if x.dim() == 3: + order = order.unsqueeze(-1).expand(-1, -1, x.size(-1)) + if reorder: + y = x.new(x.size()) + y.scatter_(1, order, x) + return y + else: + return x.gather(1, order) + + # ar_mask is a Boolean matrix of same shape as input, with 1s on the # tokens that should be generated -def masked_inplace_autoregression(model, batch_size, input, ar_mask): - +def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None): for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)): i = (ar_mask.sum(0) > 0).nonzero() if i.min() > 0: - model( - mygpt.BracketedSequence(input, 0, i.min()) - ) # Needed to initialize the model's cache + # Needed to initialize the model's cache + model(mygpt.BracketedSequence(input, 0, i.min()), order=order) for s in range(i.min(), i.max() + 1): - output = model(mygpt.BracketedSequence(input, s, 1)).x + output = model(mygpt.BracketedSequence(input, s, 1), order=order).x logits = output[:, s] if args.deterministic_synthesis: t_next = logits.argmax(1) @@ -148,8 +175,154 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask): ###################################################################### +def compute_perplexity(model, split="train"): + with torch.autograd.no_grad(): + t = model.training + model.eval() + + nb_samples, acc_loss = 0, 0.0 + + for input in task.batches(split=split): + input = input.to(device) + order = random_order(input, task.height * task.width) + input = shuffle(input, order) + output = model(mygpt.BracketedSequence(input), order=order).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_loss += loss.item() * input.size(0) + nb_samples += input.size(0) + + model.train(t) + + return math.exp(min(100, acc_loss / nb_samples)) + + +###################################################################### + + +def oneshot_policy_loss(mazes, output, policies, height, width): + masks = (mazes == maze.v_empty).unsqueeze(-1) + targets = policies.permute(0, 2, 1) * masks + output = output * masks + return -(output.log_softmax(-1) * targets).sum() / masks.sum() + + +def oneshot_trace_loss(mazes, output, policies, height, width): + masks = mazes == maze.v_empty + targets = maze.stationary_densities( + mazes.view(-1, height, width), policies.view(-1, 4, height, width) + ).flatten(-2) + targets = targets * masks + output = output.squeeze(-1) * masks + return (output - targets).abs().sum() / masks.sum() + + +def oneshot(gpt, task): + t = gpt.training + gpt.eval() + + if args.oneshot_input == "head": + dim_in = args.dim_model + elif args.oneshot_input == "deep": + dim_in = args.dim_model * args.nb_blocks * 2 + else: + raise ValueError(f"{args.oneshot_input=}") + + if args.oneshot_output == "policy": + dim_out = 4 + compute_loss = oneshot_policy_loss + elif args.oneshot_output == "trace": + dim_out = 1 + compute_loss = oneshot_trace_loss + else: + raise ValueError(f"{args.oneshot_output=}") + + model = nn.Sequential( + nn.Linear(dim_in, args.dim_model), + nn.ReLU(), + nn.Linear(args.dim_model, args.dim_model), + nn.ReLU(), + nn.Linear(args.dim_model, dim_out), + ).to(device) + + for n_epoch in range(args.nb_epochs): + learning_rate = learning_rate_schedule[n_epoch] + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + + acc_train_loss, nb_train_samples = 0, 0 + for mazes, policies in task.policy_batches(split="train"): + order = random_order(mazes, task.height * task.width) + x = shuffle(mazes, order) + x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x + output_gpt = shuffle(x, order, reorder=True) + output = model(output_gpt) + + loss = compute_loss(mazes, output, policies, task.height, task.width) + acc_train_loss += loss.item() * mazes.size(0) + nb_train_samples += mazes.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc_test_loss, nb_test_samples = 0, 0 + for mazes, policies in task.policy_batches(split="test"): + order = random_order(mazes, task.height * task.width) + x = shuffle(mazes, order) + x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x + output_gpt = shuffle(x, order, reorder=True) + output = model(output_gpt) + loss = compute_loss(mazes, output, policies, task.height, task.width) + acc_test_loss += loss.item() * mazes.size(0) + nb_test_samples += mazes.size(0) + + log_string( + f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}" + ) + + # ------------------- + mazes = task.test_input[:32, : task.height * task.width] + policies = task.test_policies[:32] + order = random_order(mazes, task.height * task.width) + x = shuffle(mazes, order) + x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x + output_gpt = shuffle(x, order, reorder=True) + output = model(output_gpt) + if args.oneshot_output == "policy": + targets = policies.permute(0, 2, 1) + scores = ( + (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0 + ).float() + elif args.oneshot_output == "trace": + targets = maze.stationary_densities( + mazes.view(-1, task.height, task.width), + policies.view(-1, 4, task.height, task.width), + ).flatten(-2) + scores = output + else: + raise ValueError(f"{args.oneshot_output=}") + + scores = scores.reshape(-1, task.height, task.width) + mazes = mazes.reshape(-1, task.height, task.width) + targets = targets.reshape(-1, task.height, task.width) + maze.save_image( + os.path.join( + args.result_dir, + f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png", + ), + mazes=mazes, + score_paths=scores, + score_truth=targets, + ) + # ------------------- + + gpt.train(t) + + +###################################################################### + + class Task: - def batches(self, split="train"): + def batches(self, split="train", nb_to_use=-1, desc=None): pass def vocabulary_size(self): @@ -187,34 +360,57 @@ class TaskMaze(Task): self.width = width self.device = device - mazes_train, paths_train = maze.create_maze_data( + train_mazes, train_paths, train_policies = maze.create_maze_data( nb_train_samples, height=height, width=width, nb_walls=nb_walls, progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), ) - mazes_train, paths_train = mazes_train.to(device), paths_train.to(device) - self.train_input = self.map2seq(mazes_train, paths_train) - self.nb_codes = self.train_input.max() + 1 + self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device)) + self.train_policies = train_policies.flatten(-2).to(device) - mazes_test, paths_test = maze.create_maze_data( + test_mazes, test_paths, test_policies = maze.create_maze_data( nb_test_samples, height=height, width=width, nb_walls=nb_walls, progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), ) - mazes_test, paths_test = mazes_test.to(device), paths_test.to(device) - self.test_input = self.map2seq(mazes_test, paths_test) + self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) + self.test_policies = test_policies.flatten(-2).to(device) - def batches(self, split="train", nb_to_use=-1): + self.nb_codes = self.train_input.max() + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input if nb_to_use > 0: input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def policy_batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + policies = self.train_policies if split == "train" else self.test_policies + input = input[:, : self.height * self.width] + policies = policies * (input != maze.v_wall)[:, None] + + if nb_to_use > 0: + input = input[:nb_to_use] + policies = policies[:nb_to_use] + + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + zip(input.split(self.batch_size), policies.split(self.batch_size)), + dynamic_ncols=True, + desc=desc, ): yield batch @@ -227,7 +423,12 @@ class TaskMaze(Task): result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 - masked_inplace_autoregression(model, self.batch_size, result, ar_mask) + result *= 1 - ar_mask + order = random_order(result, self.height * self.width) + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, order=order + ) + result = shuffle(result, order, reorder=True) mazes, paths = self.seq2map(result) nb_correct += maze.path_correctness(mazes, paths).long().sum() nb_total += mazes.size(0) @@ -256,13 +457,19 @@ class TaskMaze(Task): input = self.test_input[:32] result = input.clone() ar_mask = result.new_zeros(result.size()) - ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths) + maze.save_image( + os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"), + mazes=mazes, + target_paths=paths, + predicted_paths=predicted_paths, + path_correct=maze.path_correctness(mazes, predicted_paths), + ) model.train(t) @@ -276,9 +483,9 @@ task = TaskMaze( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, - height=args.world_height, - width=args.world_width, - nb_walls=args.world_nb_walls, + height=args.maze_height, + width=args.maze_width, + nb_walls=args.maze_nb_walls, device=device, ) @@ -333,8 +540,6 @@ else: ###################################################################### -nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default - token_count = 0 for input in task.batches(split="train"): token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) @@ -368,13 +573,20 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}") ############################## -nb_samples_seen = 0 +if nb_epochs_finished >= args.nb_epochs: + n_epoch = nb_epochs_finished + train_perplexity = compute_perplexity(model, split="train") + test_perplexity = compute_perplexity(model, split="test") + + log_string( + f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" + ) -if nb_epochs_finished >= nb_epochs: - task.produce_results(nb_epochs_finished, model) + task.produce_results(n_epoch, model) -for n_epoch in range(nb_epochs_finished, nb_epochs): +############################## +for n_epoch in range(nb_epochs_finished, args.nb_epochs): learning_rate = learning_rate_schedule[n_epoch] log_string(f"learning_rate {learning_rate}") @@ -386,7 +598,7 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): elif args.optim == "adamw": optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) else: - raise ValueError(f"Unknown optimizer {args.optim}.") + raise ValueError(f"{args.optim=}") model.train() @@ -394,41 +606,25 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): for input in task.batches(split="train"): input = input.to(device) - output = model(mygpt.BracketedSequence(input)).x + order = random_order(input, task.height * task.width) + input = shuffle(input, order) + output = model(mygpt.BracketedSequence(input), order=order).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) - nb_samples_seen += input.size(0) optimizer.zero_grad() loss.backward() optimizer.step() - with torch.autograd.no_grad(): - - model.eval() + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + test_perplexity = compute_perplexity(model, split="test") - nb_test_samples, acc_test_loss = 0, 0.0 - - for input in task.batches(split="test"): - input = input.to(device) + log_string( + f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" + ) - # input, loss_masks, true_images = task.excise_last_image(input) - # input, loss_masks = task.add_true_image(input, true_images, loss_masks) - - output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - - log_string( - f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" - ) - - task.produce_results(n_epoch, model) + task.produce_results(n_epoch, model) checkpoint = { "nb_epochs_finished": n_epoch + 1, @@ -444,3 +640,8 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): log_string(f"saved checkpoint {checkpoint_name}") ###################################################################### + +if args.oneshot: + oneshot(model, task) + +######################################################################