-######################################################################
-
-
-class Problem:
- def generate_sequences(self, nb):
- pass
-
- def seq2str(self, seq):
- return "[NOT IMPLEMENTED]"
-
-
-####################
-
-
-class ProblemLevel0(Problem):
- def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
- self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
- self.seq[:, len_prompt] = 10
-
- def generate_sequences(self, nb):
- sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
- ar_mask = (sequences == 10).long()
- ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
- return sequences, ar_mask
-
-
-class ProblemLevel1(Problem):
- def __init__(self, nb_operators=100, len_source=5, len_result=8):
- self.len_source = len_source
- self.len_result = len_result
- self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
- self.operators = F.one_hot(
- torch.rand(nb_operators, len_result, len_source).argmax(-1),
- num_classes=len_source,
- )
-
- def generate_sequences(self, nb):
- nb_operators = torch.randint(self.operators.size(0), (nb,))
- operators = self.operators[nb_operators]
- nb_operators = (
- nb_operators[:, None]
- // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
- ) % 10
- marker1 = torch.full((nb, 1), 10)
- # source = torch.randint(10, (nb, self.len_source))
- source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
- marker2 = torch.full((nb, 1), 11)
- result = operators.bmm(source[:, :, None]).squeeze(-1)
- sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
- ar_mask = (sequences == 11).long()
- ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
- return sequences, ar_mask
-
- def seq2str(self, seq):
- return "".join("0123456789|>"[x.item()] for x in seq)
-
-
-class ProblemLevel2(Problem):
- def __init__(self, len_source=5, len_result=8):
- self.len_source = len_source
- self.len_result = len_result
-
- def generate_sequences(self, nb):
- operators = F.one_hot(
- torch.rand(nb, self.len_result, self.len_source).argmax(-1),
- num_classes=self.len_source,
- )
- source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
- marker1 = torch.full((nb, 1), 10)
- result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
- marker2 = torch.full((nb, 1), 11)
- source2 = torch.randint(10, (nb, self.len_source))
- marker3 = torch.full((nb, 1), 12)
- result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
-
- sequences = torch.cat(
- (source1, marker1, result1, marker2, source2, marker3, result2), 1
- )
- ar_mask = (sequences == 12).long()
- ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
- return sequences, ar_mask
-
- def seq2str(self, seq):
- return "".join("0123456789>|~"[x.item()] for x in seq)
-
-