projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
c4eb660
)
Update
author
François Fleuret
<francois@fleuret.org>
Sun, 12 Mar 2023 09:59:32 +0000
(10:59 +0100)
committer
François Fleuret
<francois@fleuret.org>
Sun, 12 Mar 2023 09:59:32 +0000
(10:59 +0100)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
7adb804
..
517f29a
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@
-129,9
+129,8
@@
def masked_inplace_autoregression(model, batch_size, input, ar_mask):
for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
i = (ar_mask.sum(0) > 0).nonzero()
if i.min() > 0:
for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
i = (ar_mask.sum(0) > 0).nonzero()
if i.min() > 0:
- model(
- mygpt.BracketedSequence(input, 0, i.min())
- ) # Needed to initialize the model's cache
+ # Needed to initialize the model's cache
+ model(mygpt.BracketedSequence(input, 0, i.min()))
for s in range(i.min(), i.max() + 1):
output = model(mygpt.BracketedSequence(input, s, 1)).x
logits = output[:, s]
for s in range(i.min(), i.max() + 1):
output = model(mygpt.BracketedSequence(input, s, 1)).x
logits = output[:, s]
@@
-419,9
+418,6
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
for input in task.batches(split="test"):
input = input.to(device)
for input in task.batches(split="test"):
input = input.to(device)
- # input, loss_masks, true_images = task.excise_last_image(input)
- # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), input)
acc_test_loss += loss.item() * input.size(0)
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), input)
acc_test_loss += loss.item() * input.size(0)