##############################
# maze options
-parser.add_argument("--world_height", type=int, default=13)
+parser.add_argument("--maze_height", type=int, default=13)
-parser.add_argument("--world_width", type=int, default=21)
+parser.add_argument("--maze_width", type=int, default=21)
-parser.add_argument("--world_nb_walls", type=int, default=15)
+parser.add_argument("--maze_nb_walls", type=int, default=15)
######################################################################
mazes, paths = self.seq2map(input)
_, predicted_paths = self.seq2map(result)
- maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
+ maze.save_image(
+ f"result_{n_epoch:04d}.png",
+ mazes,
+ paths,
+ predicted_paths,
+ maze.path_correctness(mazes, predicted_paths),
+ )
model.train(t)
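# maze.path_correctness is not part of this diff; all it needs to produce is one
# boolean per maze (True if the predicted path is valid). Purely as an
# illustration, such a check could look like the sketch below. The cell codes
# EMPTY/WALL/START/GOAL, the convention that nonzero cells of `paths` mark the
# proposed path, and the assumption of exactly one start and one goal per maze
# are all hypothetical, not the repository's actual encoding.
import torch

EMPTY, WALL, START, GOAL = 0, 1, 2, 3  # hypothetical cell codes


def path_correctness_sketch(mazes, paths):
    ok = torch.zeros(mazes.size(0), dtype=torch.bool)
    for n in range(mazes.size(0)):
        maze, path = mazes[n], paths[n] != 0
        if (path & (maze == WALL)).any():
            continue  # the proposed path runs through a wall
        start = tuple((maze == START).nonzero()[0].tolist())
        goal = tuple((maze == GOAL).nonzero()[0].tolist())
        walkable = path | (maze == START) | (maze == GOAL)
        # flood fill over walkable cells from the start, 4-connectivity
        seen, frontier = {start}, [start]
        while frontier:
            i, j = frontier.pop()
            for a, b in ((i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1)):
                if 0 <= a < maze.size(0) and 0 <= b < maze.size(1):
                    if walkable[a, b] and (a, b) not in seen:
                        seen.add((a, b))
                        frontier.append((a, b))
        ok[n] = goal in seen  # correct iff the goal is reachable along the path
    return ok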
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
- height=args.world_height,
- width=args.world_width,
- nb_walls=args.world_nb_walls,
+ height=args.maze_height,
+ width=args.maze_width,
+ nb_walls=args.maze_nb_walls,
device=device,
)
######################################################################
-def save_image(name, mazes, target_paths, predicted_paths=None):
+def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=None):
mazes, target_paths = mazes.cpu(), target_paths.cpu()
colors = torch.tensor(  # one RGB row per cell code: empty, wall, start, goal, path
    [[255, 255, 255], [0, 0, 0], [0, 255, 0], [127, 127, 255], [255, 0, 0]]
)  # shades reconstructed here; only the code-to-color indexing matters below
mazes = colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
target_paths = (
    colors[target_paths.reshape(-1)]
    .reshape(target_paths.size() + (-1,))
    .permute(0, 3, 1, 2)
)
- img = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+ imgs = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
if predicted_paths is not None:
predicted_paths = predicted_paths.cpu()
predicted_paths = (
    colors[predicted_paths.reshape(-1)]
    .reshape(predicted_paths.size() + (-1,))
    .permute(0, 3, 1, 2)
)
- img = torch.cat((img, predicted_paths.unsqueeze(1)), 1)
-
- img = img.reshape((-1,) + img.size()[2:]).float() / 255.0
-
- torchvision.utils.save_image(img, name, padding=1, pad_value=0.85, nrow=6)
+ imgs = torch.cat((imgs, predicted_paths.unsqueeze(1)), 1)
+
+ # NxKxCxHxW
+ if path_correct is None:
+ path_correct = torch.zeros(imgs.size(0)) <= 1  # no flags given: all-True, every path counts as correct
+ path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
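+ # per-sample background color: light gray if the path is correct, red otherwise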
+ img = torch.tensor([224, 224, 224]).view(1, -1, 1, 1) * path_correct + torch.tensor(
+ [255, 0, 0]
+ ).view(1, -1, 1, 1) * (1 - path_correct)
+ img = img.expand(
+ -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
+ ).clone()
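+ # paste the K panels (maze, target path, predicted path) side by side inside the colored frame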
+ for k in range(imgs.size(1)):
+ img[
+ :,
+ :,
+ 1 : 1 + imgs.size(3),
+ 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
+ ] = imgs[:, k]
+
+ img = img.float() / 255.0
+
+ torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
######################################################################
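# Usage sketch for the extended save_image signature above. The shapes and the
# 0..4 cell-code range are assumptions chosen to match the five-color palette;
# it presumes torch, torchvision, and save_image are in scope as in maze.py.
if __name__ == "__main__":
    N, H, W = 8, 13, 21
    mazes = torch.randint(0, 2, (N, H, W))  # 0 = empty, 1 = wall (assumed codes)
    target_paths = torch.randint(0, 5, (N, H, W))
    predicted_paths = torch.randint(0, 5, (N, H, W))
    path_correct = torch.rand(N) < 0.5  # one flag per maze: gray vs. red frame
    save_image("demo.png", mazes, target_paths, predicted_paths, path_correct)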