- paths = colors[paths.reshape(-1)].reshape(paths.size() + (-1,)).permute(0, 3, 1, 2)
+ target_paths = (
+ colors[target_paths.reshape(-1)]
+ .reshape(target_paths.size() + (-1,))
+ .permute(0, 3, 1, 2)
+ )
+ img = torch.cat((mazes.unsqueeze(1), target_paths.unsqueeze(1)), 1)
+
+ if predicted_paths is not None:
+ predicted_paths = predicted_paths.cpu()
+ predicted_paths = (
+ colors[predicted_paths.reshape(-1)]
+ .reshape(predicted_paths.size() + (-1,))
+ .permute(0, 3, 1, 2)
+ )
+ img = torch.cat((img, predicted_paths.unsqueeze(1)), 1)