From 27bb2d1ab23422f26b05f88b4e0573deeb075cd2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 12 Mar 2023 10:59:32 +0100 Subject: [PATCH] Update --- beaver.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/beaver.py b/beaver.py index 7adb804..517f29a 100755 --- a/beaver.py +++ b/beaver.py @@ -129,9 +129,8 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask): for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)): i = (ar_mask.sum(0) > 0).nonzero() if i.min() > 0: - model( - mygpt.BracketedSequence(input, 0, i.min()) - ) # Needed to initialize the model's cache + # Needed to initialize the model's cache + model(mygpt.BracketedSequence(input, 0, i.min())) for s in range(i.min(), i.max() + 1): output = model(mygpt.BracketedSequence(input, s, 1)).x logits = output[:, s] @@ -419,9 +418,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): for input in task.batches(split="test"): input = input.to(device) - # input, loss_masks, true_images = task.excise_last_image(input) - # input, loss_masks = task.add_true_image(input, true_images, loss_masks) - output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) -- 2.20.1