- ar_mask = (sequences == 11).long()
- ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
- return sequences, ar_mask
-
- def seq2str(self, seq):
- return "".join("0123456789|>_"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemLenId(Problem):
- def __init__(self, len_max=10):
- self.len_max = len_max
-
- def generate_sequences(self, nb):
- k = torch.arange(self.len_max * 3 + 3)[None, :]
- l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
- i = torch.randint(10, (2, nb))[:, :, None]
- a = l[0]
- b = l[0] + 1 + l[1]
- c = l[0] + 1 + l[1] + 1 + l[0]
- sequences = (
- (k < a) * i[0]
- + (k == a) * 10
- + (k > a) * (k < b) * i[1]
- + (k == b) * 11
- + (k > b) * (k < c) * i[1]
- + (k >= c) * 12
+ k2 = l.argmax(dim=1, keepdim=True)
+ m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
+ s = s * m + 11 * (1 - m)
+ a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
+ a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
+ sequences = torch.cat(
+ (
+ s,
+ torch.full((nb, 1), 12),
+ a1,
+ torch.full((nb, 1), 12),
+ a2,
+ torch.full((nb, 1), 12),
+ ),
+ 1,