+ return "".join("0123456789|>"[x.item()] for x in seq)
+
+
+class ProblemLevel2(Problem):
+ def __init__(self, len_source=5, len_result=8):
+ self.len_source = len_source
+ self.len_result = len_result
+
+ def generate_sequences(self, nb):
+ operators = F.one_hot(
+ torch.rand(nb, self.len_result, self.len_source).argmax(-1),
+ num_classes=self.len_source,
+ )
+ source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+ # source1 = torch.randint(10, (nb, self.len_source))
+ marker1 = torch.full((nb, 1), 10)
+ result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
+ marker2 = torch.full((nb, 1), 11)
+ source2 = torch.randint(10, (nb, self.len_source))
+ marker3 = torch.full((nb, 1), 12)
+ result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
+
+ sequences = torch.cat(
+ (source1, marker1, result1, marker2, source2, marker3, result2), 1
+ )
+ ar_mask = (sequences == 12).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join("0123456789>|~"[x.item()] for x in seq)