From 1be1638f9906a1071dc82ebc6f35f8fc0eb91a3d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 22:03:24 +0200 Subject: [PATCH] Update. --- tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks.py b/tasks.py index 332d6c5..5ac78cb 100755 --- a/tasks.py +++ b/tasks.py @@ -101,7 +101,7 @@ class ProblemLevel1(Problem): b = a + 1 + self.len_prompt sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64) nb_operators = torch.randint(self.operators.size(0), (nb,)) - sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a)) % 10 + sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10 sequences[:, a] = 10 sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1)) sequences[:, b] = 11 @@ -115,7 +115,7 @@ class ProblemLevel1(Problem): return sequences, ar_mask def seq2str(self, seq): - return "".join(self.id2char[x.item()] for x in seq) + return "".join("0123456789|>"[x.item()] for x in seq) #################### -- 2.20.1