X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=c68fe7660afc7d9eeb01379576bf9eb0dc2f51c4;hb=3518f58472ceb6cf7ea3cdb62aabc7a368501348;hp=920a446f920e6cd2beb67ac0df96457bfac55225;hpb=61cd7a140e44ccb966bad941fa31e395e51e50e2;p=beaver.git diff --git a/beaver.py b/beaver.py index 920a446..c68fe76 100755 --- a/beaver.py +++ b/beaver.py @@ -75,11 +75,11 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # maze options -parser.add_argument("--world_height", type=int, default=13) +parser.add_argument("--maze_height", type=int, default=13) -parser.add_argument("--world_width", type=int, default=21) +parser.add_argument("--maze_width", type=int, default=21) -parser.add_argument("--world_nb_walls", type=int, default=15) +parser.add_argument("--maze_nb_walls", type=int, default=15) ###################################################################### @@ -227,6 +227,7 @@ class TaskMaze(Task): result = input.clone() ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 + result *= 1-ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(result) nb_correct += maze.path_correctness(mazes, paths).long().sum() @@ -256,13 +257,19 @@ class TaskMaze(Task): input = self.test_input[:32] result = input.clone() ar_mask = result.new_zeros(result.size()) - ar_mask[:, self.height * self.width :] = 1 + result *= 1-ar_mask masked_inplace_autoregression(model, self.batch_size, result, ar_mask) mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths) + maze.save_image( + f"result_{n_epoch:04d}.png", + mazes, + paths, + predicted_paths, + maze.path_correctness(mazes, predicted_paths), + ) model.train(t) @@ -276,9 +283,9 @@ task = TaskMaze( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, - height=args.world_height, - width=args.world_width, - nb_walls=args.world_nb_walls, + height=args.maze_height, + width=args.maze_width, + nb_walls=args.maze_nb_walls, device=device, )