+ pe = (
+ torch.sin(
+ t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
+ .unsqueeze(0)
+ .expand(bs.x.size(0), -1, -1)
+ )
+
+ order_output = order + 1
+ order_input = F.pad(order + 1, (1, -1))
+
+ pe_input = pe.gather(
+ 1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ )
+ pe_output = pe.gather(
+ 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))