X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=c93010a2f9869389ad6e0685da65a83f46cd9e8b;hb=07ce0d849569e234f2d7714d7438dfab29542610;hp=5ea4668203f0f4eaf9ccbc586f2d58a348fe0a3f;hpb=0633abb4fa80284003f1d6d1b5b9e758c35ed10f;p=picoclvr.git diff --git a/mygpt.py b/mygpt.py index 5ea4668..c93010a 100755 --- a/mygpt.py +++ b/mygpt.py @@ -5,6 +5,11 @@ # Written by Francois Fleuret +# This is an implementation from scratch of a "GPT", that is a model +# composed of several causal self-attention blocks. It is equipped +# with a caching mechanism for keys and values to avoid a O(N^3) cost +# for auto-regression. + import math import torch @@ -14,19 +19,6 @@ from torch.nn import functional as F ###################################################################### - -class WithResidual(nn.Module): - def __init__(self, *f): - super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) - - def forward(self, bs): - bs.x = bs.x + self.f(bs).x - return bs - - -###################################################################### - # A BracketedSequence is a BxTx... tensor with a first and a nb time # steps to compute. @@ -78,6 +70,19 @@ class CacheWrapper(nn.Module): ############################## +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + bs.x = bs.x + self.f(bs).x + return bs + + +############################## + + class AddPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() @@ -197,7 +202,6 @@ class MyGPT(nn.Module): dropout=0.0, len_max=1e5, ): - super().__init__() assert dim_model % nb_heads == 0 @@ -254,11 +258,34 @@ class MyGPT(nn.Module): bs = self.readout(bs) return bs + # ar_mask is a tensor with 0s and 1s, of same shape as input, with + # 1s where tokens should be generated. The others are kept + # unchanged. + + def masked_inplace_autoregression( + self, input, ar_mask, forbidden_tokens=None, deterministic_synthesis=False + ): + to_generate = (ar_mask.sum(0) > 0).nonzero() + if to_generate.min() > 0: + self( + BracketedSequence(input, 0, to_generate.min()) + ) # Needed to initialize the model's cache + for s in range(to_generate.min(), to_generate.max() + 1): + output = self(BracketedSequence(input, s, 1)).x + logits = output[:, s] + if forbidden_tokens is not None: + logits = logits.masked_fill(forbidden_tokens, float("-inf")) + if deterministic_synthesis: + t_next = logits.argmax(1) + else: + dist = torch.distributions.categorical.Categorical(logits=logits) + t_next = dist.sample() + input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] + ###################################################################### if __name__ == "__main__": - print("Basic check.") vocabulary_size = 10 @@ -283,8 +310,6 @@ if __name__ == "__main__": z = model(BracketedSequence(x, s, 1)) y2[:, s] = z.x[:, s] - # print(y1.max(dim = 2).values) - # print(y2.max(dim = 2).values) print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}") ######################################################################