+ self.seq[:, len_prompt] = 10
+
+ def generate_sequences(self, nb):
+ sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+ ar_mask = (sequences == 10).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+
class ProblemLevel1(Problem):
    """Sequences of the form  <operator id digits> 10 <prompt digits> 11 <result digits>,
    where each operator is a fixed random one-hot matrix that maps the
    prompt digits to the result digits."""

    def __init__(self, nb_operators=100, len_prompt=5, len_result=8):
        self.len_prompt = len_prompt
        self.len_result = len_result
        # Decimal digits used to spell an operator id.
        # NOTE(review): for nb_operators=100 this yields 3 digits although
        # ids only reach 99 — harmless (leading zero), kept for layout
        # compatibility.
        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
        # One (len_result, len_prompt) one-hot matrix per operator: row r
        # selects which prompt digit becomes result digit r.
        self.operators = F.one_hot(
            torch.rand(nb_operators, len_result, len_prompt).argmax(-1),
            num_classes=len_prompt,
        )

    def generate_sequences(self, nb):
        """Generate nb sequences; ar_mask is 1 on the result part only."""
        a = self.len_nb_operator
        b = a + 1 + self.len_prompt
        sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64)
        nb_operators = torch.randint(self.operators.size(0), (nb,))
        # Spell the operator id in decimal, most significant digit first.
        # Integer floor division is exact; the original float `/` only
        # worked via truncation when stored into the int64 tensor.
        sequences[:, :a] = (
            nb_operators[:, None] // 10 ** torch.arange(a - 1, -1, -1)
        ) % 10
        sequences[:, a] = 10  # prompt separator
        sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1))
        sequences[:, b] = 11  # result separator

        o = self.operators[nb_operators]
        p = sequences[:, a + 1 : b]
        # Apply each sampled operator to its prompt digits.
        # (Removed a leftover debug print of the tensor sizes here.)
        sequences[:, b + 1 :] = o.bmm(p[:, :, None]).squeeze(-1)
        # Mask the positions strictly after the 11 separator.
        ar_mask = (sequences == 11).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Decode a token sequence to text (10 -> '|', 11 -> '>')."""
        return "".join("0123456789|>"[x.item()] for x in seq)
+
+
+####################
+
+
+class ProblemAddition(Problem):
+ def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
+ self.nb_digits = nb_digits
+ self.zero_padded = zero_padded
+ self.inverted_result = inverted_result
+ self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
+ self.id2char = dict([(n, c) for c, n in self.char2id.items()])
+
+ def tensorize(self, strings):
+ len_max = max([len(x) for x in strings])
+ return torch.cat(
+ [
+ torch.tensor(
+ [
+ [self.char2id[c] for c in s + "$" * (len_max - len(s))]
+ for s in strings
+ ]
+ )
+ ],
+ 0,
+ )