[0, 255, 0], # start
[127, 127, 255], # goal
[255, 0, 0], # path
+ [128, 128, 128], # error
]
)
+ def safe_colors(x):
+ m = (x >= 0).long() * (x < colors.size(0) - 1).long()
+ return colors[x * m + (colors.size(0) - 1) * (1 - m)]
+
mazes = mazes.cpu()
c_mazes = (
if predicted_paths is not None:
predicted_paths = predicted_paths.cpu()
c_predicted_paths = (
- colors[predicted_paths.reshape(-1)]
+ safe_colors(predicted_paths.reshape(-1))
.reshape(predicted_paths.size() + (-1,))
.permute(0, 3, 1, 2)
)
-1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
).clone()
- print(f"{img.size()=} {imgs.size()=}")
-
for k in range(imgs.size(1)):
img[
:,