def one_shot(gpt, task):
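# remember whether the GPT was in training mode, then switch it to eval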
t = gpt.training
gpt.eval()
- model = nn.Linear(args.dim_model, 4).to(device)
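+ # new readout head: a one-hidden-layer MLP (hidden width dim_model) in place
+ # of the single linear layer, still mapping to 4 outputs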
+ model = nn.Sequential(
+     nn.Linear(args.dim_model, args.dim_model),
+     nn.ReLU(),
+     nn.Linear(args.dim_model, 4)
+ ).to(device)
for n_epoch in range(args.nb_epochs):
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
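+ # pick this epoch's learning rate from learning_rate_schedule (defined
+ # elsewhere); the Adam optimizer is rebuilt every epoch with that rate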
+ learning_rate = learning_rate_schedule[n_epoch]
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
acc_train_loss, nb_train_samples = 0, 0
for input, targets in task.policy_batches(split="train"):