From: François Fleuret
Date: Tue, 14 Mar 2023 09:02:02 +0000 (+0100)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=beaver.git;a=commitdiff_plain;h=1e6f089e67087e8cf1bcb6865e8d405b0a50f372

Update.
---

diff --git a/beaver.py b/beaver.py
index d86ef1f..1408f0b 100755
--- a/beaver.py
+++ b/beaver.py
@@ -172,10 +172,15 @@ def compute_perplexity(model, split="train"):
 def one_shot(gpt, task):
     t = gpt.training
     gpt.eval()
-    model = nn.Linear(args.dim_model, 4).to(device)
+    model = nn.Sequential(
+        nn.Linear(args.dim_model, args.dim_model),
+        nn.ReLU(),
+        nn.Linear(args.dim_model, 4)
+    ).to(device)
 
     for n_epoch in range(args.nb_epochs):
-        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+        learning_rate = learning_rate_schedule[n_epoch]
+        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
         acc_train_loss, nb_train_samples = 0, 0
         for input, targets in task.policy_batches(split="train"):