# [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
- def forward(self, bs, order=None):
+ def forward(self, bs, order): # NxTxD, T
if bs.first == 0:
- t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
- :, None
- ]
- j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+ t = (
+ torch.arange(bs.x.size(1) + 1, dtype=bs.x.dtype, device=bs.x.device)[
+ :, None
+ ]
+ - 1
+ )
+ j = torch.arange(bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device)[
None, :
]
k = j % 2
- self.pe = torch.sin(
- t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ pe = (
+ torch.sin(
+ t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
+ .unsqueeze(0)
+ .expand(bs.x.size(0), -1, -1)
)
- if order is not None:
- self.pe = self.pe.gather(1, order.unsqueeze(-1).expand_as(self.pe))
+ order_output = order + 1
+ order_input = torch.cat(
+ (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
+ )
+
+ self.pe = torch.cat(
+ (
+ pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
+ pe.gather(
+ 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+ ),
+ ),
+ 2,
+ )
self.cache_y = bs.x.new(bs.x.size())
self.cache_y[:, bs.first : bs.first + bs.nb] = (
- bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+ bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
)
bs.x = self.cache_y
def forward(self, bs, mode="standard", order=None):
bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
- if order is not None:
- order = F.pad(order + 1, (1, -1))
+ if order is None:
+ order = torch.arange(bs.x.size(1), device=bs.x.device)[None, :].expand_as(
+ bs.x
+ )
bs = self.embedding(bs)
bs = self.pe(bs, order)
r += [bs.slice()]
bs = BracketedSequence(torch.cat(r, -1))
else:
- raise ValueError
+ raise ValueError(f"{mode=}")
return bs