b = a + 1 + self.len_prompt
sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64)
nb_operators = torch.randint(self.operators.size(0), (nb,))
- sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a)) % 10
+ sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10
sequences[:, a] = 10
sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1))
sequences[:, b] = 11
return sequences, ar_mask
def seq2str(self, seq):
- return "".join(self.id2char[x.item()] for x in seq)
+ return "".join("0123456789|>"[x.item()] for x in seq)
####################