model.eval()
mazes = task.test_input[:32].clone()
mazes[:, task.height * task.width :] = 0
+ policies = task.test_policies[:32]
+ targets = maze.stationary_densities(
+ mazes[:, : task.height * task.width].view(-1, task.height, task.width),
+ policies.view(-1, 4, task.height, task.width),
+ ).flatten(-2)
output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
output = F.softmax(output, dim=2)
print(f"{output.size()=}")
-1, task.height, task.width
)
mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
+ targets = targets.reshape(-1, task.height, task.width)
paths = task.test_input[:32, task.height * task.width :].reshape(
-1, task.height, task.width
)
maze.save_image(
os.path.join(args.result_dir, filename),
mazes=mazes,
- target_paths=paths,
+ # target_paths=paths,
score_paths=proba_path,
- # score_truth=targets,
+ score_truth=targets,
)
log_string(f"wrote {filename}")