X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=d0fda7e4182878043e74a260f0676654fc12193f;hb=HEAD;hp=0cf70e0f674317b0c5c4884d248eb55a18ef6232;hpb=8492656cf0cc5de4f7e2c4aa8ccb717193293b40;p=culture.git

diff --git a/mygpt.py b/mygpt.py
index 0cf70e0..041d28c 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -19,6 +19,45 @@ from torch.nn import functional as F
 
 ######################################################################
 
+
+class BSQ(nn.Module):
+    def __init__(self, L):
+        super().__init__()
+        self.L = L
+
+    def forward(self, input, indexes=False):
+        norm = input.pow(2).sum(dim=2, keepdim=True).sqrt()
+        u = input / norm
+
+        if indexes:
+            return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1)
+
+        hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1)
+        if self.training:
+            self.loss += u.mean(dim=0).tanh().pow(2).mean()
+            return hat_u + u - u.detach()
+        else:
+            return hat_u
+
+
+class RandomBypass(nn.Module):
+    def __init__(self, m, p):
+        super().__init__()
+        self.m = m
+        self.p = p
+
+    def forward(self, x):
+        y = self.m(x)
+
+        if self.training:
+            u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None]
+            return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size())
+        else:
+            return y
+
+
+######################################################################
+
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -114,6 +153,30 @@ class AddPositionalEncoding(nn.Module):
 
 ##############################
 
+
+class EncoderHead(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, bs):
+        z = self.fc(bs.x).mean(dim=1)
+        return z, bs.x.shape
+
+
+class DecoderBottom(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, z_shape):
+        z, shape = z_shape
+        y = self.fc(z)[:, None, :].expand(shape)
+        return BracketedSequence(y)
+
+
+##############################
+
 class QKVAttention(nn.Module):
     def __init__(
         self,
@@ -121,7 +184,7 @@ class QKVAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
-        causal=False,
+        compute_attzero=None,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -129,7 +192,7 @@ class QKVAttention(nn.Module):
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
-        self.causal = causal
+        self.compute_attzero = compute_attzero
         self.attention_dropout = attention_dropout
         self.record_attention = False
 
@@ -141,10 +204,6 @@ class QKVAttention(nn.Module):
     def forward(self, bs_q):
         x_q = bs_q.x
 
-        assert (
-            self.causal or bs_q.complete()
-        ), "Partial evaluation is only possible for causal models"
-
         if bs_q.first == 0:
             self.cache_k = x_q.new_zeros(
                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
@@ -169,12 +228,12 @@ class QKVAttention(nn.Module):
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
         ) / math.sqrt(self.w_q.size(1))
 
-        if self.causal:
+        if self.compute_attzero is not None:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.compute_attzero(
+                    torch.arange(x_q.size(1), device=q.device)[:, None],
+                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                )[None, None, :, :]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -201,6 +260,23 @@ class QKVAttention(nn.Module):
 
 ##############################
 
+
+class NoiseInjector(nn.Module):
+    def __init__(self, identifier=None):
+        super().__init__()
+        self.noise_std = 0.0
+        self.identifier = identifier
+
+    def forward(self, x):
+        if self.noise_std > 0:
+            x = x * (
+                1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long()
+            )
+        return x
+
+
+##############################
+
 class MyGPT(nn.Module):
     def __init__(
         self,
@@ -210,7 +286,8 @@ class MyGPT(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
-        causal=False,
+        compute_attzero=None,
+        autoencoder_dim=-1,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -218,6 +295,8 @@ class MyGPT(nn.Module):
 
         assert dim_model % nb_heads == 0
 
+        self.temperature = 1.0
+
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
             AddPositionalEncoding(len_max),
@@ -228,19 +307,23 @@ class MyGPT(nn.Module):
         for b in range(nb_blocks):
             trunk_blocks += [
                 WithResidual(
-                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        NoiseInjector(identifier=("attention", b)),
+                    ),
                     QKVAttention(
                         dim_in=dim_model,
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
-                        causal=causal,
+                        compute_attzero=compute_attzero,
                         attention_dropout=dropout,
                     ),
                 ),
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
+                        NoiseInjector(identifier=("ffw", b)),
                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
                         nn.ReLU(),
                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
@@ -255,6 +338,26 @@ class MyGPT(nn.Module):
             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
         )
 
+        # -------------------------------------------------------
+        if autoencoder_dim > 0:
+            self.encoder = nn.Sequential(
+                *(
+                    trunk_blocks[: nb_blocks // 2]
+                    + [EncoderHead(dim_model, autoencoder_dim)]
+                )
+            )
+
+            self.decoder = nn.Sequential(
+                *(
+                    [
+                        DecoderBottom(autoencoder_dim, dim_model),
+                        AddPositionalEncoding(len_max),
+                    ]
+                    + trunk_blocks[nb_blocks // 2 :]
+                )
+            )
+        # -------------------------------------------------------
+
         with torch.no_grad():
             for m in self.modules():
                 if isinstance(m, nn.Embedding):
@@ -264,35 +367,58 @@ class MyGPT(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
+        for m in self.modules():
+            m.loss = 0
+
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
+        bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+
+        for m in self.modules():
+            self.loss += m.loss
+
         return bs
 
-    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
-    # 1s where tokens should be generated. The others are kept
-    # unchanged.
+    def encode(self, bs):
+        bs = self.embedding(bs)
+        z = self.encoder(bs)
+        return z
 
-    def masked_inplace_autoregression(
-        self, input, ar_mask, forbidden_tokens=None, deterministic_synthesis=False
-    ):
-        to_generate = (ar_mask.sum(0) > 0).nonzero()
-        if to_generate.min() > 0:
-            self(
-                BracketedSequence(input, 0, to_generate.min())
-            )  # Needed to initialize the model's cache
-        for s in range(to_generate.min(), to_generate.max() + 1):
-            output = self(BracketedSequence(input, s, 1)).x
-            logits = output[:, s]
-            if forbidden_tokens is not None:
-                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
-            if deterministic_synthesis:
-                t_next = logits.argmax(1)
+    def decode(self, z_shape):
+        bs = self.decoder(z_shape)
+        bs = self.readout(bs)
+        return bs
+
+    def partial_forward(self, bs, start_layer=None, end_layer=None):
+        if start_layer is None:
+            # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+            bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+            bs = self.embedding(bs)
+            if end_layer is not None:
+                return self.trunk[:end_layer](bs)
             else:
-                dist = torch.distributions.categorical.Categorical(logits=logits)
-                t_next = dist.sample()
-            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+                bs = self.trunk(bs)
+                bs = self.readout(bs)
+                return bs
+        else:
+            bs = self.trunk[start_layer:](bs)
+            bs = self.trunk(bs)
+            bs = self.readout(bs)
+            return bs
+
+    def reset_transformations(self):
+        self.temperature = 1.0
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                m.noise_std = 0.0
+
+    def set_noise_injection(self, noise_std, identifier=None):
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                if identifier is None or identifier == m.identifier:
+                    m.noise_std = noise_std
 
     def record_attention(self, v=True):
         for m in self.modules():
@@ -323,7 +449,6 @@ if __name__ == "__main__":
         nb_heads=2,
        nb_blocks=2,
         dropout=0.1,
-        causal=True,
     )
 
     model.eval()
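Usage sketch (not part of the diff above): the new compute_attzero argument replaces the old causal flag with a callable that receives broadcastable tensors of query and key time indexes and returns a boolean mask of the attention entries to suppress. Under that reading, the mask the removed causal branch hard-wired can be recovered as below; the helper name causal_attzero, the module import path mygpt, and the toy dimensions are illustrative assumptions, not from the repository.

import torch

from mygpt import BracketedSequence, MyGPT


# Suppress attention from a query at time t_q to any key at a later time t_k,
# exactly the comparison the removed causal=True code computed internally.
def causal_attzero(t_q, t_k):
    return t_q < t_k


model = MyGPT(
    vocabulary_size=100,
    dim_model=16,
    dim_keys=4,
    dim_hidden=16,
    nb_heads=2,
    nb_blocks=2,
    compute_attzero=causal_attzero,
    dropout=0.1,
)
model.eval()

x = torch.randint(100, (2, 10))
logits = model(BracketedSequence(x)).x  # shape (2, 10, 100)

The constructor call mirrors the one in the file's __main__ test block, with compute_attzero taking over the role of the removed causal=True.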
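A second sketch, under the same assumptions and reusing model and x from the snippet above, of the inference-time knobs this diff introduces: the temperature attribute that divides the readout logits, set_noise_injection(), which targets the NoiseInjector modules by their ("attention", b) / ("ffw", b) identifiers, and reset_transformations(), which undoes both.

# Dividing the logits by a temperature below 1.0 sharpens the output
# distribution; values above 1.0 flatten it.
model.temperature = 0.7

# Randomly flip the sign of roughly 5% of the activations feeding the
# feed-forward path of block 0 (the LayerNorm output, before the first
# Linear); pass identifier=None to hit every NoiseInjector.
model.set_noise_injection(0.05, identifier=("ffw", 0))

noisy_logits = model(BracketedSequence(x)).x

# Restore the untouched model: temperature back to 1.0, all noise off.
model.reset_transformations()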
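Finally, a hedged sketch of the optional autoencoder path enabled by autoencoder_dim > 0, which splits the trunk into an encoder (first half of the blocks plus EncoderHead) and a decoder (DecoderBottom, positional encoding, second half). It assumes the classes not shown in this diff (CacheWrapper, WithResidual, AddPositionalEncoding) behave as elsewhere in mygpt.py; the latent size of 8 and the tensor shapes are illustrative.

model_ae = MyGPT(
    vocabulary_size=100,
    dim_model=16,
    dim_keys=4,
    dim_hidden=16,
    nb_heads=2,
    nb_blocks=2,
    autoencoder_dim=8,
)
model_ae.eval()

x_ae = torch.randint(100, (2, 10))

# encode() returns a (latent, shape) pair: EncoderHead averages the projected
# activations over time, and the input shape is kept so DecoderBottom can
# broadcast the latent back to a full sequence.
z_shape = model_ae.encode(BracketedSequence(x_ae))

# decode() expands the latent, re-adds positional encodings, runs the second
# half of the trunk, and applies the readout; .x holds logits of shape (2, 10, 100).
reconstruction_logits = model_ae.decode(z_shape).x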