a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
sequences = torch.cat(
- (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1
+ (
+ s,
+ torch.full((nb, 1), 12),
+ a1,
+ torch.full((nb, 1), 12),
+ a2,
+ torch.full((nb, 1), 12),
+ ),
+ 1,
)
ar_mask = (sequences == 12).long()
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
def seq2str(self, seq):
- return "".join("0123456789+-|"[x.item()] for x in seq)
+ return "".join("0123456789-+|"[x.item()] for x in seq)
####################