def oneshot(model, learning_rate_scheduler, task):
t = model.training
model.eval()
- mazes = task.test_input[:32].clone()
+ mazes = task.test_input[:48].clone()
mazes[:, task.height * task.width :] = 0
- policies = task.test_policies[:32]
+ policies = task.test_policies[:48]
targets = maze.stationary_densities(
mazes[:, : task.height * task.width].view(-1, task.height, task.width),
policies.view(-1, 4, task.height, task.width),
)
mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
targets = targets.reshape(-1, task.height, task.width)
- paths = task.test_input[:32, task.height * task.width :].reshape(
+ paths = task.test_input[:48, task.height * task.width :].reshape(
-1, task.height, task.width
)
filename = f"oneshot.png"
)
# -------------------
- mazes = task.test_input[:32, : task.height * task.width]
- policies = task.test_policies[:32]
+ mazes = task.test_input[:48, : task.height * task.width]
+ policies = task.test_policies[:48]
output_gpt = eval_mygpt(
gpt, mazes, mode=args.oneshot_input, prompt_len=task.height * task.width
)
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
- input = self.test_input[:32]
+ input = self.test_input[:48]
result = input.clone()
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1