)
order_output = order + 1
- order_input = torch.cat(
- (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
- )
+ order_input = F.pad(order + 1, (1, -1))
- self.pe = torch.cat(
- (
- pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
- pe.gather(
- 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
- ),
- ),
- 2,
+ pe_input = pe.gather(
+ 1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ )
+ pe_output = pe.gather(
+ 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
)
+ self.pe = torch.cat((pe_input, pe_output), 2)
self.cache_y = bs.x.new(bs.x.size())
self.cache_y[:, bs.first : bs.first + bs.nb] = (