- self.pe = torch.cat(
- (
- pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
- pe.gather(
- 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
- ),
- ),
- 2,
+ pe_input = pe.gather(
+ 1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ )
+ pe_output = pe.gather(
+ 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))