+ # -------------------
+ mazes = task.test_input[:32, : task.height * task.width]
+ policies = task.test_policies[:32]
+ output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+ output = model(output_gpt)
+ if args.oneshot_output == "policy":
+ targets = policies.permute(0, 2, 1)
+ scores = (
+ (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
+ ).float()
+ elif args.oneshot_output == "trace":
+ targets = maze.stationary_densities(
+ mazes.view(-1, task.height, task.width),
+ policies.view(-1, 4, task.height, task.width),
+ ).flatten(-2)
+ scores = output.flatten(-2)
+ else:
+ raise ValueError(f"{args.oneshot_output=}")
+
+ scores = scores.reshape(-1, task.height, task.width)
+ mazes = mazes.reshape(-1, task.height, task.width)
+ targets = targets.reshape(-1, task.height, task.width)
+ maze.save_image(
+ os.path.join(
+ args.result_dir,
+ f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
+ ),
+ mazes=mazes,
+ score_paths=scores,
+ score_truth=targets,