t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
k = j%2
- return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+ pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+ return x + pe # Let broadcasting to its job
##############################