Update
index 311ff6b..4555b1e 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -106,20 +106,16 @@ class AddPositionalEncoding(nn.Module):
             )
 
             order_output = order + 1
-            order_input = torch.cat(
-                (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
-            )
+            order_input = F.pad(order + 1, (1, -1))
 
-            self.pe = torch.cat(
-                (
-                    pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
-                    pe.gather(
-                        1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
-                    ),
-                ),
-                2,
+            pe_input = pe.gather(
+                1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+            )
+            pe_output = pe.gather(
+                1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
             )
 
+            self.pe = torch.cat((pe_input, pe_output), 2)
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
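
The F.pad call above is a more compact way to build the shifted order tensor: padding one zero on the left of the last dimension and cropping one element on the right reproduces the previous torch.cat construction. A minimal sketch of the equivalence (the shapes are assumptions chosen for illustration):

    import torch
    import torch.nn.functional as F

    order = torch.randint(10, (2, 5))  # hypothetical (N, T) order tensor
    old = torch.cat((order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1)
    new = F.pad(order + 1, (1, -1))    # pad one zero on the left, crop one on the right
    assert torch.equal(old, new)
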
@@ -136,13 +132,27 @@ class AddPositionalEncoding(nn.Module):
 
 class QKVAttention(nn.Module):
     def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
+        amm_generator=None,
     ):
         super().__init__()
 
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
+        if amm_generator is None:
+            self.amm_generator = (
+                lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :]
+            )
+        else:
+            self.amm_generator = amm_generator
+
         self.causal = causal
         self.attention_dropout = attention_dropout
 
@@ -179,10 +189,9 @@ class QKVAttention(nn.Module):
 
         if self.causal:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[
+                    None, None, :, :
+                ]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
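
For reference, the default amm_generator defined in the previous hunk reproduces the strictly upper-triangular mask that was previously built inline, and the [None, None, :, :] indexing restores the 4-D shape expected by masked_fill. A quick sanity check (T=4 chosen arbitrarily):

    import torch

    amm_generator = lambda d: torch.arange(d)[:, None] < torch.arange(d)[None, :]
    mask = amm_generator(4)
    # tensor([[False,  True,  True,  True],
    #         [False, False,  True,  True],
    #         [False, False, False,  True],
    #         [False, False, False, False]])
    cache_attzero = mask[None, None, :, :]  # shape (1, 1, 4, 4); True marks masked (future) positions
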
@@ -219,6 +228,7 @@ class MyGPT(nn.Module):
         causal=False,
         dropout=0.0,
         len_max=1e5,
+        amm_generator=None,
     ):
         super().__init__()
 
@@ -242,6 +252,7 @@ class MyGPT(nn.Module):
                         nb_heads=nb_heads,
                         causal=causal,
                         attention_dropout=dropout,
+                        amm_generator=amm_generator,
                     ),
                 ),
                 WithResidual(
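
The last hunks thread the new amm_generator argument from the MyGPT constructor down to each QKVAttention block, so a caller can replace the default strictly-causal mask. A hedged sketch using only the QKVAttention signature visible in this diff (the dimensions are hypothetical and the one-step-lookahead pattern is purely illustrative):

    import torch
    from mygpt import QKVAttention  # module name taken from the file path in this diff

    def lookahead_amm(d):
        # True marks attention entries to mask out; here each position may also
        # attend one step ahead instead of being strictly causal
        return torch.arange(d)[:, None] + 1 < torch.arange(d)[None, :]

    attn = QKVAttention(
        dim_in=64, dim_qk=16, dim_v=16, nb_heads=4,  # hypothetical sizes
        causal=True, amm_generator=lookahead_amm,
    )

Passing the same callable as MyGPT(..., amm_generator=...) forwards it unchanged to every attention layer, as the final hunk shows.
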