+ # -------------------
+ input = task.test_input[:32, : task.height * task.width]
+ targets = task.test_policies[:32].permute(0, 2, 1)
+ output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x
+ output = model(output_gpt)
+ scores = (
+ (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
+ ).float()
+ scores = scores.reshape(-1, task.height, task.width)
+ input = input.reshape(-1, task.height, task.width)
+ maze.save_image(
+ os.path.join(
+ args.result_dir,
+ f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
+ ),
+ mazes=input,
+ score_paths=scores,
+ )
+ # -------------------
+