Added the (missing) W_o
[mygpt.git] / mygpt.py
index 7bf25b5..42960b1 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -41,31 +41,36 @@ class PositionalEncoding(nn.Module):
 ##############################
 
 class QKVAttention(nn.Module):
-    def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0):
+    def __init__(self, dim_in, dim_qk, dim_v,
+                 nb_heads = 1, causal = False, attention_dropout = 0.0):
         super().__init__()
 
         def randw(*d):
             return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
 
-        self.wq = randw(nb_heads, dim_qk, dim_in)
-        self.wk = randw(nb_heads, dim_qk, dim_in)
-        self.wv = randw(nb_heads, dim_v, dim_in)
+        self.w_q = randw(nb_heads, dim_qk, dim_in)
+        self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(nb_heads, dim_in, dim_v)
         self.causal = causal
         self.attention_dropout = attention_dropout
 
-    def forward(self, x):
-        q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
-        k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
-        v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
-        r = math.sqrt(q.size(3))
-        a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
+    def forward(self, x_q, x_kv = None):
+        if x_kv is None: x_kv = x_q
+        q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
+        k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
+        v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
+        a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
         if self.causal:
-            mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
+            mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
+                   < torch.arange(a.size(3), device = q.device)[None, None, None, :]
             a = a.masked_fill(mask, float('-inf'))
         a = a.softmax(dim = 3)
         a = F.dropout(a, self.attention_dropout, self.training)
         y = torch.einsum('nhts,nhsd->nhtd', a, v)
-        return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
+        y = torch.einsum('nhtd,hcd->ntc', y, self.w_o)
+
+        return y
 
 ##############################
 
@@ -119,3 +124,20 @@ class MyGPT(nn.Module):
         return x
 
 ######################################################################
+
+if __name__ == '__main__':
+    print('Basic check.')
+
+    vocabulary_size = 10
+    x = torch.randint(vocabulary_size, (25, 100))
+
+    model = MyGPT(
+        vocabulary_size = vocabulary_size,
+        dim_model = 16, dim_keys = 50, dim_hidden = 100,
+        nb_heads = 2, nb_blocks = 3,
+        dropout = 0.1
+    )
+
+    y = model(x)
+
+######################################################################