######################################################################
+# ra_mask is boolean, with 1s on the values to generate
+
+
def masked_inplace_autoregression(
- model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+ model,
+ batch_size,
+ input,
+ ar_mask,
+ forbidden_tokens=None,
+ progress_bar_desc="autoregression",
+ device=torch.device("cpu"),
):
- for input, ar_mask in tqdm.tqdm(
- zip(input.split(batch_size), ar_mask.split(batch_size)),
- dynamic_ncols=True,
- desc="autoregression",
- total=input.size(0) // batch_size,
- ):
+ batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+ if progress_bar_desc is not None:
+ tqdm.tqdm(
+ batches,
+ dynamic_ncols=True,
+ desc=progress_bar_desc,
+ total=input.size(0) // batch_size,
+ )
+ for input, ar_mask in batches:
i = (ar_mask.sum(0) > 0).nonzero()
if i.min() > 0:
model(
input,
ar_masks,
forbidden_tokens,
+ progress_bar_desc=None,
device=self.device,
)
model.train(t)
for input in task.batches(split="test"):
input = input.to(device)
- # input, loss_masks, true_images = task.excise_last_image(input)
- # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), input)
acc_test_loss += loss.item() * input.size(0)
# Written by Francois Fleuret <francois@fleuret.org>
+# This is an implementation from scratch of a "GPT", that is a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid a O(N^3) cost
+# for auto-regression.
+
import math
import torch
######################################################################
-
-class WithResidual(nn.Module):
- def __init__(self, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
- def forward(self, bs):
- bs.x = bs.x + self.f(bs).x
- return bs
-
-
-######################################################################
-
# A BracketedSequence is a BxTx... tensor with a first and a nb time
# steps to compute.
##############################
+class WithResidual(nn.Module):
+ def __init__(self, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+ def forward(self, bs):
+ bs.x = bs.x + self.f(bs).x
+ return bs
+
+
+##############################
+
+
class AddPositionalEncoding(nn.Module):
def __init__(self, len_max):
super().__init__()