+ self.seq[:, len_prompt] = 10
+
def generate_sequences(self, nb):
    """Sample `nb` sequences uniformly (with replacement) from ``self.seq``.

    Returns a pair ``(sequences, ar_mask)`` where ``ar_mask`` is 1 at every
    position strictly after the first occurrence of token 10 in each row.
    """
    picks = torch.randint(self.seq.size(0), (nb,))
    drawn = self.seq[picks]
    hits = (drawn == 10).long()
    # cumsum counts markers seen so far (inclusive); subtracting hits makes the
    # count exclusive, so the mask switches on strictly after the first 10.
    mask = (hits.cumsum(1) - hits).clamp(max=1)
    return drawn, mask
+
+
class ProblemLevel1(Problem):
    """Sequences of the form: operator-id digits, '|' (token 10), source
    digits, '>' (token 11), result digits.

    Each operator is a fixed (len_result, len_source) matrix with one-hot
    rows, so every result position copies exactly one source digit.
    """

    def __init__(self, nb_operators=100, len_source=5, len_result=8):
        self.len_source = len_source
        self.len_result = len_result
        # Number of decimal digits needed to spell any operator id.
        # NOTE(review): relies on float log; exact powers of ten may be off by
        # one digit depending on rounding — confirm against intended widths.
        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
        # One selection matrix per operator: argmax of random scores picks a
        # source index per result row, then one_hot turns it into a 0/1 row.
        self.operators = F.one_hot(
            torch.rand(nb_operators, len_result, len_source).argmax(-1),
            num_classes=len_source,
        )

    def generate_sequences(self, nb):
        """Return (sequences, ar_mask); ar_mask is 1 strictly after '>' (11)."""
        op_ids = torch.randint(self.operators.size(0), (nb,))
        operators = self.operators[op_ids]
        # Decimal digits of the operator id, most significant first.
        op_digits = (
            op_ids[:, None] // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
        ) % 10
        marker1 = torch.full((nb, 1), 10)
        source = torch.randint(10, (nb, self.len_source))
        marker2 = torch.full((nb, 1), 11)
        # (nb, len_result, len_source) @ (nb, len_source, 1) -> (nb, len_result)
        result = operators.bmm(source[:, :, None]).squeeze(-1)
        # Debug print statements removed (they logged dtypes and sizes on
        # every call).
        sequences = torch.cat((op_digits, marker1, source, marker2, result), 1)
        ar_mask = (sequences == 11).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render a token sequence as text: digits, '|' for 10, '>' for 11."""
        return "".join("0123456789|>"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemAddition(Problem):
def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
    """Configure the addition problem and build the 13-token vocabulary
    (digits 0-9, '+', '=', and '$' as the pad/terminator)."""
    self.nb_digits = nb_digits
    self.zero_padded = zero_padded
    self.inverted_result = inverted_result
    symbols = "0123456789+=$"
    self.char2id = {c: n for n, c in enumerate(symbols)}
    self.id2char = {n: c for n, c in enumerate(symbols)}
+
def tensorize(self, strings):
    """Encode a batch of strings as a LongTensor of shape (len(strings), max_len).

    Each character is mapped through ``self.char2id``; shorter strings are
    right-padded with '$' up to the longest string's length.

    Raises KeyError on characters outside the vocabulary and ValueError when
    ``strings`` is empty (same as the original).
    """
    len_max = max(len(s) for s in strings)
    # The original wrapped the single tensor in a no-op torch.cat([...], 0);
    # build it directly instead.
    return torch.tensor(
        [[self.char2id[c] for c in s.ljust(len_max, "$")] for s in strings]
    )