+def oneshot_policy_loss(mazes, output, policies, height, width):
+ masks = (mazes == maze.v_empty).unsqueeze(-1)
+ targets = policies.permute(0, 2, 1) * masks
+ output = output * masks
+ return -(output.log_softmax(-1) * targets).sum() / masks.sum()
+
+
+def oneshot_trace_loss(mazes, output, policies, height, width):
+ masks = mazes == maze.v_empty
+ targets = maze.stationary_densities(
+ mazes.view(-1, height, width), policies.view(-1, 4, height, width)
+ ).flatten(-2)
+ targets = targets * masks
+ output = output.squeeze(-1) * masks
+ return (output - targets).abs().sum() / masks.sum()
+
+
+def oneshot(model, learning_rate_scheduler, task):
+ t = model.training
+ model.eval()
+ mazes = task.test_input[:48].clone()
+ mazes[:, task.height * task.width :] = 0
+ policies = task.test_policies[:48]
+ targets = maze.stationary_densities(
+ mazes[:, : task.height * task.width].view(-1, task.height, task.width),
+ policies.view(-1, 4, task.height, task.width),
+ ).flatten(-2)
+ output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
+ output = F.softmax(output, dim=2)
+ print(f"{output.size()=}")
+ proba_path = output[:, task.height * task.width :, 4].reshape(
+ -1, task.height, task.width
+ )
+ mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
+ targets = targets.reshape(-1, task.height, task.width)
+ paths = task.test_input[:48, task.height * task.width :].reshape(
+ -1, task.height, task.width
+ )
+ filename = f"oneshot.png"
+ maze.save_image(
+ os.path.join(args.result_dir, filename),
+ mazes=mazes,
+ # target_paths=paths,
+ score_paths=proba_path,
+ score_truth=targets,
+ )
+ log_string(f"wrote {filename}")
+
+
+def oneshot_old(gpt, learning_rate_scheduler, task):