Oups
[picoclvr.git] / mygpt.py
index 5ea4668..131c822 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -5,6 +5,11 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+# This is an implementation from scratch of a "GPT", that is a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid a O(N^3) cost
+# for auto-regression.
+
 import math
 
 import torch
@@ -14,19 +19,6 @@ from torch.nn import functional as F
 
 ######################################################################
 
-
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
-
-
-######################################################################
-
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -53,6 +45,9 @@ class BracketedSequence:
     def slice(self):
         return self.x[:, self.first : self.first + self.nb]
 
+    def complete(self):
+        return self.first == 0 and self.nb == self.x.size(1)
+
 
 ######################################################################
 
@@ -70,9 +65,19 @@ class CacheWrapper(nn.Module):
         else:
             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
 
-        bs.x = self.cache_y
+        return BracketedSequence(self.cache_y, bs.first, bs.nb)
 
-        return bs
+
+##############################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
 
 
 ##############################
@@ -103,9 +108,7 @@ class AddPositionalEncoding(nn.Module):
             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
         )
 
-        bs.x = self.cache_y
-
-        return bs
+        return BracketedSequence(self.cache_y, bs.first, bs.nb)
 
 
 ##############################
@@ -113,7 +116,13 @@ class AddPositionalEncoding(nn.Module):
 
 class QKVAttention(nn.Module):
     def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
     ):
         super().__init__()
 
@@ -122,6 +131,7 @@ class QKVAttention(nn.Module):
 
         self.causal = causal
         self.attention_dropout = attention_dropout
+        self.record_attention = False
 
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
@@ -131,6 +141,10 @@ class QKVAttention(nn.Module):
     def forward(self, bs_q):
         x_q = bs_q.x
 
+        assert (
+            self.causal or bs_q.complete()
+        ), "Partial evaluation is only possible for causal models"
+
         if bs_q.first == 0:
             self.cache_k = x_q.new_zeros(
                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
@@ -143,6 +157,7 @@ class QKVAttention(nn.Module):
         q = torch.einsum(
             "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
         )
+
         self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
             "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
         )
@@ -168,6 +183,10 @@ class QKVAttention(nn.Module):
             )
 
         a = a.softmax(dim=3)
+
+        if self.record_attention:
+            self.a = a
+
         a = F.dropout(a, self.attention_dropout, self.training)
 
         y = torch.einsum(
@@ -176,9 +195,7 @@ class QKVAttention(nn.Module):
 
         self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
 
-        bs_q.x = self.cache_y
-
-        return bs_q
+        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
 
 
 ##############################
@@ -197,7 +214,6 @@ class MyGPT(nn.Module):
         dropout=0.0,
         len_max=1e5,
     ):
-
         super().__init__()
 
         assert dim_model % nb_heads == 0
@@ -248,43 +264,83 @@ class MyGPT(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
-        bs.x = F.pad(bs.x, (1, -1))
+        # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
         return bs
 
+    # ar_mask is a tensor with 0s and 1s, of same shape as input, with
+    # 1s where tokens should be generated. The others are kept
+    # unchanged.
+
+    def masked_inplace_autoregression(
+        self,
+        input,
+        ar_mask,
+        deterministic_synthesis=False,
+        forbidden_tokens=None,
+        forced_biases=None,
+    ):
+        to_generate = (ar_mask.sum(0) > 0).nonzero()
+        if to_generate.min() > 0:
+            self(
+                BracketedSequence(input, 0, to_generate.min())
+            )  # Needed to initialize the model's cache
+        for s in range(to_generate.min(), to_generate.max() + 1):
+            output = self(BracketedSequence(input, s, 1)).x
+            logits = output[:, s]
+            if forbidden_tokens is not None:
+                logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+            if forced_biases is not None:
+                logits = logits + forced_biases[None, :]
+            if deterministic_synthesis:
+                t_next = logits.argmax(1)
+            else:
+                dist = torch.distributions.categorical.Categorical(logits=logits)
+                t_next = dist.sample()
+            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+
+    def record_attention(self, v=True):
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                m.record_attention = v
+
+    def retrieve_attention(self):
+        a = []
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                a.append(m.a)
+        return a
+
 
 ######################################################################
 
 if __name__ == "__main__":
-
     print("Basic check.")
 
-    vocabulary_size = 10
-    x = torch.randint(vocabulary_size, (9, 7))
+    vocabulary_size = 3
+    x = torch.randint(vocabulary_size, (1, 5))
 
     model = MyGPT(
         vocabulary_size=vocabulary_size,
-        dim_model=18,
-        dim_keys=50,
-        dim_hidden=100,
+        dim_model=4,
+        dim_keys=2,
+        dim_hidden=2,
         nb_heads=2,
-        nb_blocks=1,
+        nb_blocks=2,
         dropout=0.1,
+        causal=True,
     )
 
     model.eval()
-
     y1 = model(BracketedSequence(x)).x
-
     y2 = torch.randn_like(y1)
     for s in range(x.size(1)):
         z = model(BracketedSequence(x, s, 1))
-        y2[:, s] = z.x[:, s]
+        y2[:, s] = z.slice()
 
-    # print(y1.max(dim = 2).values)
-    # print(y2.max(dim = 2).values)
     print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
 
 ######################################################################