X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=517f29a5f1702ea7630600a44126633603d25125;hb=27bb2d1ab23422f26b05f88b4e0573deeb075cd2;hp=920a446f920e6cd2beb67ac0df96457bfac55225;hpb=61cd7a140e44ccb966bad941fa31e395e51e50e2;p=beaver.git diff --git a/beaver.py b/beaver.py index 920a446..517f29a 100755 --- a/beaver.py +++ b/beaver.py @@ -26,9 +26,7 @@ else: ###################################################################### -parser = argparse.ArgumentParser( - description="An implementation of GPT with cache to solve a toy geometric reasoning task." -) +parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.") parser.add_argument("--log_filename", type=str, default="train.log") @@ -75,11 +73,11 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # maze options -parser.add_argument("--world_height", type=int, default=13) +parser.add_argument("--maze_height", type=int, default=13) -parser.add_argument("--world_width", type=int, default=21) +parser.add_argument("--maze_width", type=int, default=21) -parser.add_argument("--world_nb_walls", type=int, default=15) +parser.add_argument("--maze_nb_walls", type=int, default=15) ###################################################################### @@ -131,9 +129,8 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask): for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)): i = (ar_mask.sum(0) > 0).nonzero() if i.min() > 0: - model( - mygpt.BracketedSequence(input, 0, i.min()) - ) # Needed to initialize the model's cache + # Needed to initialize the model's cache + model(mygpt.BracketedSequence(input, 0, i.min())) for s in range(i.min(), i.max() + 1): output = model(mygpt.BracketedSequence(input, s, 1)).x logits = output[:, s] @@ -196,7 +193,6 @@ class TaskMaze(Task): ) mazes_train, paths_train = mazes_train.to(device), paths_train.to(device) self.train_input = self.map2seq(mazes_train, paths_train) - self.nb_codes = self.train_input.max() + 1 mazes_test, paths_test = maze.create_maze_data( nb_test_samples, @@ -208,6 +204,8 @@ class TaskMaze(Task): mazes_test, paths_test = mazes_test.to(device), paths_test.to(device) self.test_input = self.map2seq(mazes_test, paths_test) + self.nb_codes = self.train_input.max() + 1 + def batches(self, split="train", nb_to_use=-1): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -227,6 +225,7 @@ class TaskMaze(Task): result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(result) nb_correct += maze.path_correctness(mazes, paths).long().sum() @@ -256,13 +255,19 @@ class TaskMaze(Task): input = self.test_input[:32] result = input.clone() ar_mask = result.new_zeros(result.size()) - ar_mask[:, self.height * self.width :] = 1 + result *= 1 - ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths) + maze.save_image( + os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"), + mazes, + paths, + predicted_paths, + maze.path_correctness(mazes, predicted_paths), + ) model.train(t) @@ -276,9 +281,9 @@ task = TaskMaze( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, - height=args.world_height, - width=args.world_width, - nb_walls=args.world_nb_walls, + height=args.maze_height, + width=args.maze_width, + nb_walls=args.maze_nb_walls, device=device, ) @@ -413,9 +418,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): for input in task.batches(split="test"): input = input.to(device) - # input, loss_masks, true_images = task.excise_last_image(input) - # input, loss_masks = task.add_true_image(input, true_images, loss_masks) - output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0)