projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
7143cc5
)
Update.
author
François Fleuret
<francois@fleuret.org>
Tue, 14 Mar 2023 09:02:02 +0000
(10:02 +0100)
committer
François Fleuret
<francois@fleuret.org>
Tue, 14 Mar 2023 09:02:02 +0000
(10:02 +0100)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
d86ef1f
..
1408f0b
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@ -172,10 +172,15 @@ def compute_perplexity(model, split="train"):
def one_shot(gpt, task):
t = gpt.training
gpt.eval()
def one_shot(gpt, task):
t = gpt.training
gpt.eval()
- model = nn.Linear(args.dim_model, 4).to(device)
+ model = nn.Sequential(
+ nn.Linear(args.dim_model, args.dim_model),
+ nn.ReLU(),
+ nn.Linear(args.dim_model, 4)
+ ).to(device)
for n_epoch in range(args.nb_epochs):
for n_epoch in range(args.nb_epochs):
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ learning_rate = learning_rate_schedule[n_epoch]
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
acc_train_loss, nb_train_samples = 0, 0
for input, targets in task.policy_batches(split="train"):
acc_train_loss, nb_train_samples = 0, 0
for input, targets in task.policy_batches(split="train"):