nn.Linear(args.dim_model, args.dim_model),
nn.ReLU(),
nn.Linear(args.dim_model, 4),
).to(device)
nn.Linear(args.dim_model, args.dim_model),
nn.ReLU(),
nn.Linear(args.dim_model, 4),
).to(device)
learning_rate = learning_rate_schedule[n_epoch]
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
learning_rate = learning_rate_schedule[n_epoch]
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for input, targets in task.policy_batches(split="train"):
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
for input, targets in task.policy_batches(split="train"):
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
- -(output.log_softmax(-1) * targets).sum(-1).mean()
- + targets.xlogy(targets).sum(-1).mean()
+ -(output.log_softmax(-1) * targets).sum()
+ / (input == maze.v_empty).sum()
+ + targets.xlogy(targets).sum() / (input == maze.v_empty).sum()
for input, targets in task.policy_batches(split="test"):
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
for input, targets in task.policy_batches(split="test"):
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
- -(output.log_softmax(-1) * targets).sum(-1).mean()
- + targets.xlogy(targets).sum(-1).mean()
+ -(output.log_softmax(-1) * targets).sum()
+ / (input == maze.v_empty).sum()
+ + targets.xlogy(targets).sum() / (input == maze.v_empty).sum()
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
losses = (-output.log_softmax(-1) * targets + targets.xlogy(targets)).sum(-1)
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
losses = (-output.log_softmax(-1) * targets + targets.xlogy(targets)).sum(-1)
losses = losses.reshape(-1, args.maze_height, args.maze_width)
input = input.reshape(-1, args.maze_height, args.maze_width)
maze.save_image(
losses = losses.reshape(-1, args.maze_height, args.maze_width)
input = input.reshape(-1, args.maze_height, args.maze_width)
maze.save_image(