+
+ order_output = order + 1
+ order_input = F.pad(order + 1, (1, -1))
+
+ pe_input = pe.gather(
+ 1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ )
+ pe_output = pe.gather(
+ 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ )
+
+ self.pe = torch.cat((pe_input, pe_output), 2)