From 1e6f089e67087e8cf1bcb6865e8d405b0a50f372 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 14 Mar 2023 10:02:02 +0100 Subject: [PATCH] Update. --- beaver.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/beaver.py b/beaver.py index d86ef1f..1408f0b 100755 --- a/beaver.py +++ b/beaver.py @@ -172,10 +172,15 @@ def compute_perplexity(model, split="train"): def one_shot(gpt, task): t = gpt.training gpt.eval() - model = nn.Linear(args.dim_model, 4).to(device) + model = nn.Sequential( + nn.Linear(args.dim_model, args.dim_model), + nn.ReLU(), + nn.Linear(args.dim_model, 4) + ).to(device) for n_epoch in range(args.nb_epochs): - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + learning_rate = learning_rate_schedule[n_epoch] + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) acc_train_loss, nb_train_samples = 0, 0 for input, targets in task.policy_batches(split="train"): -- 2.20.1