+ pe = (
+ torch.sin(
+ t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
+ .unsqueeze(0)
+ .expand(bs.x.size(0), -1, -1)
+ )
+
+ order_output = order + 1
+ order_input = torch.cat(
+ (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
+ )
+
+ self.pe = torch.cat(
+ (
+ pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
+ pe.gather(
+ 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ ),
+ ),
+ 2,