X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=maze.py;h=4953d10de69e9b6bbfff2dac66a4b769964a5080;hb=HEAD;hp=8ac9fcef0495fbefc6c11f1652367fb44d916f40;hpb=4395f9a90218819997c706de9505cda1c86ad507;p=mygptrnn.git diff --git a/maze.py b/maze.py index 8ac9fce..4953d10 100755 --- a/maze.py +++ b/maze.py @@ -231,9 +231,14 @@ def save_image( [0, 255, 0], # start [127, 127, 255], # goal [255, 0, 0], # path + [128, 128, 128], # error ] ) + def safe_colors(x): + m = (x >= 0).long() * (x < colors.size(0) - 1).long() + return colors[x * m + (colors.size(0) - 1) * (1 - m)] + mazes = mazes.cpu() c_mazes = ( @@ -256,7 +261,7 @@ def save_image( if predicted_paths is not None: predicted_paths = predicted_paths.cpu() c_predicted_paths = ( - colors[predicted_paths.reshape(-1)] + safe_colors(predicted_paths.reshape(-1)) .reshape(predicted_paths.size() + (-1,)) .permute(0, 3, 1, 2) ) @@ -282,8 +287,6 @@ def save_image( -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4)) ).clone() - print(f"{img.size()=} {imgs.size()=}") - for k in range(imgs.size(1)): img[ :,