Update
[beaver.git] / mygpt.py
index 311ff6b..75adbf6 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -106,20 +106,16 @@ class AddPositionalEncoding(nn.Module):
             )
 
             order_output = order + 1
-            order_input = torch.cat(
-                (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
-            )
+            order_input = F.pad(order + 1, (1, -1))
 
-            self.pe = torch.cat(
-                (
-                    pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
-                    pe.gather(
-                        1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
-                    ),
-                ),
-                2,
+            pe_input = pe.gather(
+                1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+            )
+            pe_output = pe.gather(
+                1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
             )
 
+            self.pe = torch.cat((pe_input, pe_output), 2)
             self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (