    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
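    # In the code below, j indexes the feature dimension and k = j % 2; since
    # sin(x + pi/2) = cos(x), a single sin call yields sin for even j and cos for
    # odd j, and (j - k) makes dimensions 2i and 2i+1 share the frequency L^{2i/D}.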
- def forward(self, bs):
+ def forward(self, bs, order=None):
if bs.first == 0:
            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
                :, None
            ]
            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
                None, :
            ]
            k = j % 2
            self.pe = torch.sin(
                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
            )
+
+ if order is not None:
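+                # reorder the (T, D) table so that position t gets the encoding of
+                # position order[t]; `order` is assumed here to be a 1D LongTensor
+                # of length T, shared by the whole batch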
+                self.pe = self.pe.gather(0, order.unsqueeze(-1).expand_as(self.pe))
+
            self.cache_y = bs.x.new(bs.x.size())

        self.cache_y[:, bs.first : bs.first + bs.nb] = (
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

        return BracketedSequence(self.cache_y, bs.first, bs.nb)
assert dim_model % nb_heads == 0
- self.embedding = nn.Sequential(
- CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- AddPositionalEncoding(len_max),
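+        # keep the embedding and the positional encoding as separate modules so
+        # that forward() can pass the optional `order` to the latter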
+ self.embedding = CacheWrapper(
+ nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)
)
+ self.pe = AddPositionalEncoding(len_max)
trunk_blocks = []
m.bias.zero_()
m.weight.fill_(1.0)
- def forward(self, bs, mode="standard"):
- bs.x = F.pad(bs.x, (1, -1))
+ def forward(self, bs, mode="standard", order=None):
+ bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
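+        # keep `order` aligned with the right-shifted input: position 0 gets
+        # index 0 and every later position t reuses order[t - 1] + 1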
+ if order is not None:
+ order = F.pad(order + 1, (1, -1))
bs = self.embedding(bs)
+ bs = self.pe(bs, order)
+
if mode == "standard":
bs = self.trunk(bs)
bs = self.readout(bs)