Update
[beaver.git] / mygpt.py
index df6eab6..4555b1e 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -85,22 +85,41 @@ class AddPositionalEncoding(nn.Module):
 
     # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
 
-    def forward(self, bs):
+    def forward(self, bs, order):  # NxTxD, T
         if bs.first == 0:
-            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
-                :, None
-            ]
-            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+            t = (
+                torch.arange(bs.x.size(1) + 1, dtype=bs.x.dtype, device=bs.x.device)[
+                    :, None
+                ]
+                - 1
+            )
+            j = torch.arange(bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device)[
                 None, :
             ]
             k = j % 2
-            self.pe = torch.sin(
-                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+            pe = (
+                torch.sin(
+                    t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+                )
+                .unsqueeze(0)
+                .expand(bs.x.size(0), -1, -1)
+            )
+
+            order_output = order + 1
+            order_input = F.pad(order + 1, (1, -1))
+
+            pe_input = pe.gather(
+                1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+            )
+            pe_output = pe.gather(
+                1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
             )
+
+            self.pe = torch.cat((pe_input, pe_output), 2)
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
         )
 
         bs.x = self.cache_y
@@ -113,13 +132,27 @@ class AddPositionalEncoding(nn.Module):
 
 class QKVAttention(nn.Module):
     def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
+        amm_generator=None,
     ):
         super().__init__()
 
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
+        if amm_generator is None:
+            self.amm_generator = (
+                lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :]
+            )
+        else:
+            self.amm_generator = amm_generator
+
         self.causal = causal
         self.attention_dropout = attention_dropout
 
@@ -156,10 +189,9 @@ class QKVAttention(nn.Module):
 
         if self.causal:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[
+                    None, None, :, :
+                ]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -196,16 +228,16 @@ class MyGPT(nn.Module):
         causal=False,
         dropout=0.0,
         len_max=1e5,
+        amm_generator=None,
     ):
-
         super().__init__()
 
         assert dim_model % nb_heads == 0
 
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
+        self.embedding = CacheWrapper(
+            nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)
         )
+        self.pe = AddPositionalEncoding(len_max)
 
         trunk_blocks = []
 
@@ -220,6 +252,7 @@ class MyGPT(nn.Module):
                         nb_heads=nb_heads,
                         causal=causal,
                         attention_dropout=dropout,
+                        amm_generator=amm_generator,
                     ),
                 ),
                 WithResidual(
@@ -247,18 +280,34 @@ class MyGPT(nn.Module):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
 
-    def forward(self, bs):
-        bs.x = F.pad(bs.x, (1, -1))
+    def forward(self, bs, mode="standard", order=None):
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+        if order is None:
+            order = torch.arange(bs.x.size(1), device=bs.x.device)[None, :].expand_as(
+                bs.x
+            )
         bs = self.embedding(bs)
-        bs = self.trunk(bs)
-        bs = self.readout(bs)
+        bs = self.pe(bs, order)
+
+        if mode == "standard":
+            bs = self.trunk(bs)
+            bs = self.readout(bs)
+        elif mode == "head":
+            bs = self.trunk(bs)
+        elif mode == "deep":
+            r = []
+            for l in self.trunk:
+                bs = l(bs)
+                r += [bs.slice()]
+            bs = BracketedSequence(torch.cat(r, -1))
+        else:
+            raise ValueError(f"{mode=}")
         return bs
 
 
 ######################################################################
 
 if __name__ == "__main__":
-
     print("Basic check.")
 
     vocabulary_size = 10