Update
authorFrançois Fleuret <francois@fleuret.org>
Sun, 12 Mar 2023 09:59:32 +0000 (10:59 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 12 Mar 2023 09:59:32 +0000 (10:59 +0100)
beaver.py

index 7adb804..517f29a 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -129,9 +129,8 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask):
     for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
-            model(
-                mygpt.BracketedSequence(input, 0, i.min())
-            )  # Needed to initialize the model's cache
+            # Needed to initialize the model's cache
+            model(mygpt.BracketedSequence(input, 0, i.min()))
         for s in range(i.min(), i.max() + 1):
             output = model(mygpt.BracketedSequence(input, s, 1)).x
             logits = output[:, s]
@@ -419,9 +418,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            # input, loss_masks, true_images = task.excise_last_image(input)
-            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)