Update
[beaver.git] / mygpt.py
index 232b604..311ff6b 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -85,26 +85,45 @@ class AddPositionalEncoding(nn.Module):
 
     # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
 
-    def forward(self, bs, order=None):
+    def forward(self, bs, order):  # NxTxD, T
         if bs.first == 0:
-            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
-                :, None
-            ]
-            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+            t = (
+                torch.arange(bs.x.size(1) + 1, dtype=bs.x.dtype, device=bs.x.device)[
+                    :, None
+                ]
+                - 1
+            )
+            j = torch.arange(bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device)[
                 None, :
             ]
             k = j % 2
-            self.pe = torch.sin(
-                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+            pe = (
+                torch.sin(
+                    t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+                )
+                .unsqueeze(0)
+                .expand(bs.x.size(0), -1, -1)
             )
 
-            if order is not None:
-                self.pe = self.pe.gather(1, order.unsqueeze(-1).expand_as(self.pe))
+            order_output = order + 1
+            order_input = torch.cat(
+                (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
+            )
+
+            self.pe = torch.cat(
+                (
+                    pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
+                    pe.gather(
+                        1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+                    ),
+                ),
+                2,
+            )
 
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
         )
 
         bs.x = self.cache_y
@@ -252,8 +271,10 @@ class MyGPT(nn.Module):
 
     def forward(self, bs, mode="standard", order=None):
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
-        if order is not None:
-            order = F.pad(order + 1, (1, -1))
+        if order is None:
+            order = torch.arange(bs.x.size(1), device=bs.x.device)[None, :].expand_as(
+                bs.x
+            )
         bs = self.embedding(bs)
         bs = self.pe(bs, order)
 
@@ -269,7 +290,7 @@ class MyGPT(nn.Module):
                 r += [bs.slice()]
             bs = BracketedSequence(torch.cat(r, -1))
         else:
-            raise ValueError
+            raise ValueError(f"{mode=}")
         return bs