Merge branch 'dev'
[culture.git] / mygpt.py
index 5ea4668..041d28c 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -5,6 +5,11 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+# This is an implementation from scratch of a "GPT", that is a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid a O(N^3) cost
+# for auto-regression.
+
 import math
 
 import torch
 import math
 
 import torch
@@ -15,14 +20,40 @@ from torch.nn import functional as F
 ######################################################################
 
 
 ######################################################################
 
 
-class WithResidual(nn.Module):
-    def __init__(self, *f):
+class BSQ(nn.Module):
+    def __init__(self, L):
         super().__init__()
         super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+        self.L = L
 
 
-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
+    def forward(self, input, indexes=False):
+        norm = input.pow(2).sum(dim=2, keepdim=True).sqrt()
+        u = input / norm
+
+        if indexes:
+            return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1)
+
+        hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1)
+        if self.training:
+            self.loss += u.mean(dim=0).tanh().pow(2).mean()
+            return hat_u + u - u.detach()
+        else:
+            return hat_u
+
+
+class RandomBypass(nn.Module):
+    def __init__(self, m, p):
+        super().__init__()
+        self.m = m
+        self.p = p
+
+    def forward(self, x):
+        y = self.m(x)
+
+        if self.training:
+            u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None]
+            return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size())
+        else:
+            return y
 
 
 ######################################################################
 
 
 ######################################################################
@@ -53,6 +84,9 @@ class BracketedSequence:
     def slice(self):
         return self.x[:, self.first : self.first + self.nb]
 
     def slice(self):
         return self.x[:, self.first : self.first + self.nb]
 
+    def complete(self):
+        return self.first == 0 and self.nb == self.x.size(1)
+
 
 ######################################################################
 
 
 ######################################################################
 
@@ -70,9 +104,19 @@ class CacheWrapper(nn.Module):
         else:
             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
 
         else:
             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
 
-        bs.x = self.cache_y
+        return BracketedSequence(self.cache_y, bs.first, bs.nb)
 
 
-        return bs
+
+##############################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
 
 
 ##############################
 
 
 ##############################
@@ -103,9 +147,31 @@ class AddPositionalEncoding(nn.Module):
             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
         )
 
             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
         )
 
-        bs.x = self.cache_y
+        return BracketedSequence(self.cache_y, bs.first, bs.nb)
 
 
-        return bs
+
+##############################
+
+
+class EncoderHead(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, bs):
+        z = self.fc(bs.x).mean(dim=1)
+        return z, bs.x.shape
+
+
+class DecoderBottom(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, z_shape):
+        z, shape = z_shape
+        y = self.fc(z)[:, None, :].expand(shape)
+        return BracketedSequence(y)
 
 
 ##############################
 
 
 ##############################
@@ -113,15 +179,22 @@ class AddPositionalEncoding(nn.Module):
 
 class QKVAttention(nn.Module):
     def __init__(
 
 class QKVAttention(nn.Module):
     def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        compute_attzero=None,
+        attention_dropout=0.0,
     ):
         super().__init__()
 
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
     ):
         super().__init__()
 
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
-        self.causal = causal
+        self.compute_attzero = compute_attzero
         self.attention_dropout = attention_dropout
         self.attention_dropout = attention_dropout
+        self.record_attention = False
 
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
 
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
@@ -143,6 +216,7 @@ class QKVAttention(nn.Module):
         q = torch.einsum(
             "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
         )
         q = torch.einsum(
             "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
         )
+
         self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
             "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
         )
         self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
             "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
         )
@@ -154,12 +228,12 @@ class QKVAttention(nn.Module):
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
         ) / math.sqrt(self.w_q.size(1))
 
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
         ) / math.sqrt(self.w_q.size(1))
 
-        if self.causal:
+        if self.compute_attzero is not None:
             if bs_q.first == 0:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.compute_attzero(
+                    torch.arange(x_q.size(1), device=q.device)[:, None],
+                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                )[None, None, :, :]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -168,6 +242,10 @@ class QKVAttention(nn.Module):
             )
 
         a = a.softmax(dim=3)
             )
 
         a = a.softmax(dim=3)
+
+        if self.record_attention:
+            self.a = a
+
         a = F.dropout(a, self.attention_dropout, self.training)
 
         y = torch.einsum(
         a = F.dropout(a, self.attention_dropout, self.training)
 
         y = torch.einsum(
@@ -176,9 +254,24 @@ class QKVAttention(nn.Module):
 
         self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
 
 
         self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
 
-        bs_q.x = self.cache_y
+        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
+
 
 
-        return bs_q
+##############################
+
+
+class NoiseInjector(nn.Module):
+    def __init__(self, identifier=None):
+        super().__init__()
+        self.noise_std = 0.0
+        self.identifier = identifier
+
+    def forward(self, x):
+        if self.noise_std > 0:
+            x = x * (
+                1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long()
+            )
+        return x
 
 
 ##############################
 
 
 ##############################
@@ -193,15 +286,17 @@ class MyGPT(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
         dim_hidden,
         nb_heads,
         nb_blocks,
-        causal=False,
+        compute_attzero=None,
+        autoencoder_dim=-1,
         dropout=0.0,
         len_max=1e5,
     ):
         dropout=0.0,
         len_max=1e5,
     ):
-
         super().__init__()
 
         assert dim_model % nb_heads == 0
 
         super().__init__()
 
         assert dim_model % nb_heads == 0
 
+        self.temperature = 1.0
+
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
             AddPositionalEncoding(len_max),
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
             AddPositionalEncoding(len_max),
@@ -212,19 +307,23 @@ class MyGPT(nn.Module):
         for b in range(nb_blocks):
             trunk_blocks += [
                 WithResidual(
         for b in range(nb_blocks):
             trunk_blocks += [
                 WithResidual(
-                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        NoiseInjector(identifier=("attention", b)),
+                    ),
                     QKVAttention(
                         dim_in=dim_model,
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
                     QKVAttention(
                         dim_in=dim_model,
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
-                        causal=causal,
+                        compute_attzero=compute_attzero,
                         attention_dropout=dropout,
                     ),
                 ),
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
                         attention_dropout=dropout,
                     ),
                 ),
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
+                        NoiseInjector(identifier=("ffw", b)),
                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
                         nn.ReLU(),
                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
                         nn.ReLU(),
                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
@@ -239,6 +338,26 @@ class MyGPT(nn.Module):
             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
         )
 
             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
         )
 
+        # -------------------------------------------------------
+        if autoencoder_dim > 0:
+            self.encoder = nn.Sequential(
+                *(
+                    trunk_blocks[: nb_blocks // 2]
+                    + [EncoderHead(dim_model, autoencoder_dim)]
+                )
+            )
+
+            self.decoder = nn.Sequential(
+                *(
+                    [
+                        DecoderBottom(autoencoder_dim, dim_model),
+                        AddPositionalEncoding(len_max),
+                    ]
+                    + trunk_blocks[nb_blocks // 2 :]
+                )
+            )
+        # -------------------------------------------------------
+
         with torch.no_grad():
             for m in self.modules():
                 if isinstance(m, nn.Embedding):
         with torch.no_grad():
             for m in self.modules():
                 if isinstance(m, nn.Embedding):
@@ -248,43 +367,97 @@ class MyGPT(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
-        bs.x = F.pad(bs.x, (1, -1))
+        for m in self.modules():
+            m.loss = 0
+
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
+        bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+
+        for m in self.modules():
+            self.loss += m.loss
+
+        return bs
+
+    def encode(self, bs):
+        bs = self.embedding(bs)
+        z = self.encoder(bs)
+        return z
+
+    def decode(self, z_shape):
+        bs = self.decoder(z_shape)
+        bs = self.readout(bs)
         return bs
 
         return bs
 
+    def partial_forward(self, bs, start_layer=None, end_layer=None):
+        if start_layer is None:
+            # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+            bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+            bs = self.embedding(bs)
+            if end_layer is not None:
+                return self.trunk[:end_layer](bs)
+            else:
+                bs = self.trunk(bs)
+                bs = self.readout(bs)
+                return bs
+        else:
+            bs = self.trunk[start_layer:](bs)
+            bs = self.trunk(bs)
+            bs = self.readout(bs)
+            return bs
+
+    def reset_transformations(self):
+        self.temperature = 1.0
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                m.noise_std = 0.0
+
+    def set_noise_injection(self, noise_std, identifier=None):
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                if identifier is None or identifier == m.identifier:
+                    m.noise_std = noise_std
+
+    def record_attention(self, v=True):
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                m.record_attention = v
+
+    def retrieve_attention(self):
+        a = []
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                a.append(m.a)
+        return a
+
 
 ######################################################################
 
 if __name__ == "__main__":
 
 ######################################################################
 
 if __name__ == "__main__":
-
     print("Basic check.")
 
     print("Basic check.")
 
-    vocabulary_size = 10
-    x = torch.randint(vocabulary_size, (9, 7))
+    vocabulary_size = 3
+    x = torch.randint(vocabulary_size, (1, 5))
 
     model = MyGPT(
         vocabulary_size=vocabulary_size,
 
     model = MyGPT(
         vocabulary_size=vocabulary_size,
-        dim_model=18,
-        dim_keys=50,
-        dim_hidden=100,
+        dim_model=4,
+        dim_keys=2,
+        dim_hidden=2,
         nb_heads=2,
         nb_heads=2,
-        nb_blocks=1,
+        nb_blocks=2,
         dropout=0.1,
     )
 
     model.eval()
         dropout=0.1,
     )
 
     model.eval()
-
     y1 = model(BracketedSequence(x)).x
     y1 = model(BracketedSequence(x)).x
-
     y2 = torch.randn_like(y1)
     for s in range(x.size(1)):
         z = model(BracketedSequence(x, s, 1))
     y2 = torch.randn_like(y1)
     for s in range(x.size(1)):
         z = model(BracketedSequence(x, s, 1))
-        y2[:, s] = z.x[:, s]
+        y2[:, s] = z.slice()
 
 
-    # print(y1.max(dim = 2).values)
-    # print(y2.max(dim = 2).values)
     print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
 
 ######################################################################
     print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
 
 ######################################################################