X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=d0fda7e4182878043e74a260f0676654fc12193f;hb=HEAD;hp=5ea4668203f0f4eaf9ccbc586f2d58a348fe0a3f;hpb=0633abb4fa80284003f1d6d1b5b9e758c35ed10f;p=culture.git diff --git a/mygpt.py b/mygpt.py index 5ea4668..041d28c 100755 --- a/mygpt.py +++ b/mygpt.py @@ -5,6 +5,11 @@ # Written by Francois Fleuret +# This is an implementation from scratch of a "GPT", that is a model +# composed of several causal self-attention blocks. It is equipped +# with a caching mechanism for keys and values to avoid a O(N^3) cost +# for auto-regression. + import math import torch @@ -15,14 +20,40 @@ from torch.nn import functional as F ###################################################################### -class WithResidual(nn.Module): - def __init__(self, *f): +class BSQ(nn.Module): + def __init__(self, L): super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + self.L = L - def forward(self, bs): - bs.x = bs.x + self.f(bs).x - return bs + def forward(self, input, indexes=False): + norm = input.pow(2).sum(dim=2, keepdim=True).sqrt() + u = input / norm + + if indexes: + return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1) + + hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1) + if self.training: + self.loss += u.mean(dim=0).tanh().pow(2).mean() + return hat_u + u - u.detach() + else: + return hat_u + + +class RandomBypass(nn.Module): + def __init__(self, m, p): + super().__init__() + self.m = m + self.p = p + + def forward(self, x): + y = self.m(x) + + if self.training: + u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None] + return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size()) + else: + return y ###################################################################### @@ -53,6 +84,9 @@ class BracketedSequence: def slice(self): return self.x[:, self.first : self.first + self.nb] + def complete(self): + return self.first == 0 and self.nb == self.x.size(1) + ###################################################################### @@ -70,9 +104,19 @@ class CacheWrapper(nn.Module): else: self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice()) - bs.x = self.cache_y + return BracketedSequence(self.cache_y, bs.first, bs.nb) - return bs + +############################## + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb) ############################## @@ -103,9 +147,31 @@ class AddPositionalEncoding(nn.Module): bs.slice() + self.pe[bs.first : bs.first + bs.nb] ) - bs.x = self.cache_y + return BracketedSequence(self.cache_y, bs.first, bs.nb) - return bs + +############################## + + +class EncoderHead(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.fc = nn.Linear(dim_in, dim_out) + + def forward(self, bs): + z = self.fc(bs.x).mean(dim=1) + return z, bs.x.shape + + +class DecoderBottom(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.fc = nn.Linear(dim_in, dim_out) + + def forward(self, z_shape): + z, shape = z_shape + y = self.fc(z)[:, None, :].expand(shape) + return BracketedSequence(y) ############################## @@ -113,15 +179,22 @@ class AddPositionalEncoding(nn.Module): class QKVAttention(nn.Module): def __init__( - self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0 + self, + dim_in, 
+ dim_qk, + dim_v, + nb_heads=1, + compute_attzero=None, + attention_dropout=0.0, ): super().__init__() def randw(*d): return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) - self.causal = causal + self.compute_attzero = compute_attzero self.attention_dropout = attention_dropout + self.record_attention = False self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) @@ -143,6 +216,7 @@ class QKVAttention(nn.Module): q = torch.einsum( "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q ) + self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum( "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k ) @@ -154,12 +228,12 @@ class QKVAttention(nn.Module): "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb] ) / math.sqrt(self.w_q.size(1)) - if self.causal: + if self.compute_attzero is not None: if bs_q.first == 0: - self.cache_attzero = ( - torch.arange(x_q.size(1), device=q.device)[None, None, :, None] - < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] - ) + self.cache_attzero = self.compute_attzero( + torch.arange(x_q.size(1), device=q.device)[:, None], + torch.arange(x_q.size(1), device=q.device)[None, :], + )[None, None, :, :] a = a.masked_fill( self.cache_attzero[ :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb @@ -168,6 +242,10 @@ class QKVAttention(nn.Module): ) a = a.softmax(dim=3) + + if self.record_attention: + self.a = a + a = F.dropout(a, self.attention_dropout, self.training) y = torch.einsum( @@ -176,9 +254,24 @@ class QKVAttention(nn.Module): self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o - bs_q.x = self.cache_y + return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb) + - return bs_q +############################## + + +class NoiseInjector(nn.Module): + def __init__(self, identifier=None): + super().__init__() + self.noise_std = 0.0 + self.identifier = identifier + + def forward(self, x): + if self.noise_std > 0: + x = x * ( + 1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long() + ) + return x ############################## @@ -193,15 +286,17 @@ class MyGPT(nn.Module): dim_hidden, nb_heads, nb_blocks, - causal=False, + compute_attzero=None, + autoencoder_dim=-1, dropout=0.0, len_max=1e5, ): - super().__init__() assert dim_model % nb_heads == 0 + self.temperature = 1.0 + self.embedding = nn.Sequential( CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), AddPositionalEncoding(len_max), @@ -212,19 +307,23 @@ class MyGPT(nn.Module): for b in range(nb_blocks): trunk_blocks += [ WithResidual( - CacheWrapper(nn.LayerNorm((dim_model,))), + CacheWrapper( + nn.LayerNorm((dim_model,)), + NoiseInjector(identifier=("attention", b)), + ), QKVAttention( dim_in=dim_model, dim_qk=dim_keys, dim_v=dim_model // nb_heads, nb_heads=nb_heads, - causal=causal, + compute_attzero=compute_attzero, attention_dropout=dropout, ), ), WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), + NoiseInjector(identifier=("ffw", b)), nn.Linear(in_features=dim_model, out_features=dim_hidden), nn.ReLU(), nn.Linear(in_features=dim_hidden, out_features=dim_model), @@ -239,6 +338,26 @@ class MyGPT(nn.Module): nn.Linear(in_features=dim_model, out_features=vocabulary_size) ) + # ------------------------------------------------------- + if autoencoder_dim > 0: + self.encoder = nn.Sequential( + *( + trunk_blocks[: nb_blocks // 2] + + [EncoderHead(dim_model, autoencoder_dim)] + ) + ) + + self.decoder = nn.Sequential( + *( + [ + 
DecoderBottom(autoencoder_dim, dim_model), + AddPositionalEncoding(len_max), + ] + + trunk_blocks[nb_blocks // 2 :] + ) + ) + # ------------------------------------------------------- + with torch.no_grad(): for m in self.modules(): if isinstance(m, nn.Embedding): @@ -248,43 +367,97 @@ class MyGPT(nn.Module): m.weight.fill_(1.0) def forward(self, bs): - bs.x = F.pad(bs.x, (1, -1)) + for m in self.modules(): + m.loss = 0 + + bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) bs = self.trunk(bs) bs = self.readout(bs) + bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature + + for m in self.modules(): + self.loss += m.loss + + return bs + + def encode(self, bs): + bs = self.embedding(bs) + z = self.encoder(bs) + return z + + def decode(self, z_shape): + bs = self.decoder(z_shape) + bs = self.readout(bs) return bs + def partial_forward(self, bs, start_layer=None, end_layer=None): + if start_layer is None: + # print(f"GENERATE {bs.first} {bs.first+bs.nb}") + bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) + bs = self.embedding(bs) + if end_layer is not None: + return self.trunk[:end_layer](bs) + else: + bs = self.trunk(bs) + bs = self.readout(bs) + return bs + else: + bs = self.trunk[start_layer:](bs) + bs = self.trunk(bs) + bs = self.readout(bs) + return bs + + def reset_transformations(self): + self.temperature = 1.0 + for m in self.modules(): + if isinstance(m, NoiseInjector): + m.noise_std = 0.0 + + def set_noise_injection(self, noise_std, identifier=None): + for m in self.modules(): + if isinstance(m, NoiseInjector): + if identifier is None or identifier == m.identifier: + m.noise_std = noise_std + + def record_attention(self, v=True): + for m in self.modules(): + if isinstance(m, QKVAttention): + m.record_attention = v + + def retrieve_attention(self): + a = [] + for m in self.modules(): + if isinstance(m, QKVAttention): + a.append(m.a) + return a + ###################################################################### if __name__ == "__main__": - print("Basic check.") - vocabulary_size = 10 - x = torch.randint(vocabulary_size, (9, 7)) + vocabulary_size = 3 + x = torch.randint(vocabulary_size, (1, 5)) model = MyGPT( vocabulary_size=vocabulary_size, - dim_model=18, - dim_keys=50, - dim_hidden=100, + dim_model=4, + dim_keys=2, + dim_hidden=2, nb_heads=2, - nb_blocks=1, + nb_blocks=2, dropout=0.1, ) model.eval() - y1 = model(BracketedSequence(x)).x - y2 = torch.randn_like(y1) for s in range(x.size(1)): z = model(BracketedSequence(x, s, 1)) - y2[:, s] = z.x[:, s] + y2[:, s] = z.slice() - # print(y1.max(dim = 2).values) - # print(y2.max(dim = 2).values) print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}") ######################################################################
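
The `causal` flag of `QKVAttention` is replaced in this version by a `compute_attzero` callable, which receives broadcastable tensors of query and key positions and returns `True` wherever the attention coefficient must be forced to zero. Below is a minimal sketch of how the previous causal behaviour can be recovered through the new hook; the model dimensions are illustrative, and it assumes `mygpt.py` is importable as `mygpt`.

# Recovering the old causal masking through the compute_attzero hook.
# t_q and t_k are broadcastable LongTensors of query and key positions;
# returning True means "zero this attention entry", i.e. a query must not
# look at keys that lie in its future.
from mygpt import MyGPT  # assumes mygpt.py is on the Python path

def causal_attzero(t_q, t_k):
    return t_q < t_k

model = MyGPT(
    vocabulary_size=10,   # illustrative sizes, not taken from the diff
    dim_model=16,
    dim_keys=8,
    dim_hidden=32,
    nb_heads=2,
    nb_blocks=2,
    compute_attzero=causal_attzero,
)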
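
The comment added at the top of the file points out that keys and values are cached to avoid an O(N^3) cost for auto-regression, and the `__main__` check exercises this by comparing a full forward pass with per-position calls `BracketedSequence(x, s, 1)`. The sketch below, reusing the causal model from the previous example, shows one way to drive that mechanism for greedy completion; the helper `greedy_complete`, the `(batch, T)` tensor layout and the `prompt_len` prefix are illustrative assumptions, not code taken from the diff.

import torch
from mygpt import BracketedSequence

def greedy_complete(model, input, prompt_len):
    # input: LongTensor of shape (batch, T); columns < prompt_len hold a
    # fixed prompt, the remaining columns are overwritten greedily.
    model.eval()
    with torch.no_grad():
        if prompt_len > 0:
            # A first call with first=0 allocates the caches and fills the
            # keys/values of all prompt positions in one pass.
            model(BracketedSequence(input, 0, prompt_len))
        for s in range(prompt_len, input.size(1)):
            # Only position s is recomputed; keys and values of positions
            # < s come from the cache, so generating N tokens costs roughly
            # O(N^2) attention work instead of O(N^3).
            output = model(BracketedSequence(input, s, 1)).x
            input[:, s] = output[:, s].argmax(dim=1)
    return input

# Example: complete the last three positions of a random two-token prefix.
x = torch.randint(10, (2, 5))
greedy_complete(model, x, prompt_len=2)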
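
With `autoencoder_dim > 0`, the constructor additionally builds `self.encoder` (the first half of the trunk followed by an `EncoderHead` that averages a linear projection over the sequence) and `self.decoder` (a `DecoderBottom` that broadcasts the code back to the sequence shape, positional encoding, then the second half of the trunk), exposed through `encode` and `decode`. A hedged sketch of a round trip, continuing from the previous sketches, with illustrative dimensions and assuming the default `BracketedSequence(x)` bracket covers the whole sequence as in the `__main__` check:

# Round trip through the optional autoencoder path (illustrative dimensions).
model_ae = MyGPT(
    vocabulary_size=10,
    dim_model=16,
    dim_keys=8,
    dim_hidden=32,
    nb_heads=2,
    nb_blocks=4,
    autoencoder_dim=32,
)
model_ae.eval()

with torch.no_grad():
    # encode returns the pair produced by EncoderHead: the code z of shape
    # (batch, autoencoder_dim) and the shape of the embedded sequence.
    z, shape = model_ae.encode(BracketedSequence(x))
    # decode expects that same pair and returns logits of shape
    # (batch, T, vocabulary_size) after the readout.
    logits = model_ae.decode((z, shape)).x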
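
The `NoiseInjector` modules inserted before each attention and feed-forward sub-block, together with `set_noise_injection`, `reset_transformations`, `record_attention` and `retrieve_attention`, make it possible to perturb chosen blocks and to inspect attention maps. A short usage sketch, reusing `model` and `x` from the sketches above; the identifier `("ffw", 0)` and the 5% noise level are illustrative choices:

# Perturb one sub-block with sign-flip noise and record attention maps.
# ("ffw", 0) matches the NoiseInjector placed in front of the feed-forward
# sub-block of block 0, as set up in the MyGPT constructor.
model.set_noise_injection(0.05, identifier=("ffw", 0))
model.record_attention(True)

with torch.no_grad():
    model(BracketedSequence(x))

# One tensor of shape (batch, nb_heads, T, T) per attention layer.
attention_maps = model.retrieve_attention()

model.record_attention(False)
model.reset_transformations()  # noise_std back to 0.0, temperature back to 1.0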