projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (from parent 1:
91ed776
)
Update
author
François Fleuret
<francois@fleuret.org>
Tue, 14 Mar 2023 07:38:44 +0000
(08:38 +0100)
committer
François Fleuret
<francois@fleuret.org>
Tue, 14 Mar 2023 07:38:44 +0000
(08:38 +0100)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
c29dea5
..
d86ef1f
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@ -169,15 +169,6 @@ def compute_perplexity(model, split="train"):
######################################################################
######################################################################
-def nb_rank_error(output, targets):
- output = output.reshape(-1, output.size(-1))
- targets = targets.reshape(-1, targets.size(-1))
- i = outputs.argmax(1)
- # out=input.gather out[i][j]=input[i][index[i][j]]
- # u[k]=targets[k][i[k]]
- return output[targets.argmax(1)]
-
-
def one_shot(gpt, task):
t = gpt.training
gpt.eval()
def one_shot(gpt, task):
t = gpt.training
gpt.eval()
@@ -190,7 +181,10 @@ def one_shot(gpt, task):
for input, targets in task.policy_batches(split="train"):
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
for input, targets in task.policy_batches(split="train"):
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
- loss = -(output.log_softmax(-1) * targets).sum(-1).mean()
+ loss = (
+ -(output.log_softmax(-1) * targets).sum(-1).mean()
+ + targets.xlogy(targets).sum(-1).mean()
+ )
acc_train_loss += loss.item() * input.size(0)
nb_train_samples += input.size(0)
acc_train_loss += loss.item() * input.size(0)
nb_train_samples += input.size(0)
@@ -202,12 +196,15 @@ def one_shot(gpt, task):
for input, targets in task.policy_batches(split="test"):
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
for input, targets in task.policy_batches(split="test"):
output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
output = model(output_gpt)
- loss = -(output.log_softmax(-1) * targets).sum(-1).mean()
+ loss = (
+ -(output.log_softmax(-1) * targets).sum(-1).mean()
+ + targets.xlogy(targets).sum(-1).mean()
+ )
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)
-
-        print(
-            f"{n_epoch=} {acc_train_loss/nb_train_samples=} {acc_test_loss/nb_test_samples=}"
-        )
+        log_string(
+            f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
+        )

     gpt.train(t)