+ k2 = l.argmax(dim=1, keepdim=True)
+ m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
+ s = s * m + 11 * (1 - m)
+ a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
+ a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
+ sequences = torch.cat(
+ (
+ s,
+ torch.full((nb, 1), 12),
+ a1,
+ torch.full((nb, 1), 12),
+ a2,
+ torch.full((nb, 1), 12),
+ ),
+ 1,
+ )
+ ar_mask = (sequences == 12).long()