From: François Fleuret Date: Sat, 11 Mar 2023 21:00:43 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=beaver.git;a=commitdiff_plain;h=483494965a97df9a09cd572246d1f3a4dc7248de Update. --- diff --git a/beaver.py b/beaver.py index 920a446..a289867 100755 --- a/beaver.py +++ b/beaver.py @@ -75,11 +75,11 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # maze options -parser.add_argument("--world_height", type=int, default=13) +parser.add_argument("--maze_height", type=int, default=13) -parser.add_argument("--world_width", type=int, default=21) +parser.add_argument("--maze_width", type=int, default=21) -parser.add_argument("--world_nb_walls", type=int, default=15) +parser.add_argument("--maze_nb_walls", type=int, default=15) ###################################################################### @@ -262,7 +262,13 @@ class TaskMaze(Task): mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) - maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths) + maze.save_image( + f"result_{n_epoch:04d}.png", + mazes, + paths, + predicted_paths, + maze.path_correctness(mazes, predicted_paths), + ) model.train(t) @@ -276,9 +282,9 @@ task = TaskMaze( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, - height=args.world_height, - width=args.world_width, - nb_walls=args.world_nb_walls, + height=args.maze_height, + width=args.maze_width, + nb_walls=args.maze_nb_walls, device=device, ) diff --git a/maze.py b/maze.py index f4a4840..e377d2f 100755 --- a/maze.py +++ b/maze.py @@ -185,7 +185,7 @@ def create_maze_data( ###################################################################### -def save_image(name, mazes, target_paths, predicted_paths=None): +def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None): mazes, target_paths = mazes.cpu(), target_paths.cpu() colors = torch.tensor( @@ -204,7 +204,7 @@ def save_image(name, mazes, target_paths, predicted_paths=None): .reshape(target_paths.size() + (-1,)) .permute(0, 3, 1, 2) ) - img = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1) + imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1) if predicted_paths is not None: predicted_paths = predicted_paths.cpu() @@ -213,11 +213,29 @@ def save_image(name, mazes, target_paths, predicted_paths=None): .reshape(predicted_paths.size() + (-1,)) .permute(0, 3, 1, 2) ) - img = torch.cat((img, predicted_paths.unsqueeze(1)), 1) - - img = img.reshape((-1,) + img.size()[2:]).float() / 255.0 - - torchvision.utils.save_image(img, name, padding=1, pad_value=0.85, nrow=6) + imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1) + + # NxKxCxHxW + if path_correct is None: + path_correct = torch.zeros(imgs.size(0)) <= 1 + path_correct = path_correct.cpu().long().view(-1, 1, 1, 1) + img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor( + [255, 0, 0] + ).view(1, -1, 1, 1) * (1 - path_correct) + img = img.expand( + -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4)) + ).clone() + for k in range(imgs.size(1)): + img[ + :, + :, + 1 : 1 + imgs.size(3), + 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4), + ] = imgs[:, k] + + img = img.float() / 255.0 + + torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256) ######################################################################