+ order_output = order + 1
+ order_input = torch.cat(
+ (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
+ )
+
+ self.pe = torch.cat(
+ (
+ pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
+ pe.gather(
+ 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ ),
+ ),
+ 2,
+ )